Example #1
    /// Uses the player's continuous or discrete inputs to decide which
    /// action to take.
    public void DecideAction()
    {
        if (ctrl == null)
        {
            ctrl = brain.brainParameters.syft.controller;
        }

        // Remember whether a policy had already been attached before this call.
        found_policy = policy != null;

        // Look up the policy agent registered on the controller under id 1234.
        policy = ctrl.getAgent(1234);
        if (policy != null && !found_policy)
        {
            // A policy has just become available, so reset every agent and
            // let the new policy start from a clean episode.
            foreach (KeyValuePair <int, global::Agent> idAgent in brain.agents)
            {
                idAgent.Value.Reset();
            }
        }

        // States, rewards and dones are collected for debugging CollectStates;
        // only the states are actually consumed by the policy below.
        Dictionary <int, List <float> > states  = brain.CollectStates();
        Dictionary <int, float>         rewards = brain.CollectRewards();
        Dictionary <int, bool>          dones   = brain.CollectDones();

        if (brain.brainParameters.actionSpaceType == StateType.continuous)
        {
            float[] action = new float[brain.brainParameters.actionSize];
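            // Map every key currently held down onto its configured action
            // index; indices with no pressed key stay at 0.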
            foreach (ContinuousPlayerAction cha in continuousPlayerActions)
            {
                if (Input.GetKey(cha.key))
                {
                    action[cha.index] = cha.value;
                }
            }
            Dictionary <int, float[]> actions = new Dictionary <int, float[]>();
            foreach (KeyValuePair <int, global::Agent> idAgent in brain.agents)
            {
                actions.Add(idAgent.Key, action);
            }
            brain.SendActions(actions);
        }
        else
        {
            float[] action = new float[1] {
                defaultAction
            };
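            // The first bound key that is currently pressed overrides the
            // default discrete action.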
            foreach (DiscretePlayerAction dha in discretePlayerActions)
            {
                if (Input.GetKey(dha.key))
                {
                    action[0] = (float)dha.value;
                    break;
                }
            }
            Dictionary <int, float[]> actions = new Dictionary <int, float[]>();
            foreach (KeyValuePair <int, global::Agent> idAgent in brain.agents)
            {
                if (policy == null)
                {
                    // No policy is available yet, so fall back to a default
                    // zero action for this agent.
                    actions.Add(idAgent.Key, new float[1] {
                        0
                    });
                }
                else
                {
                    // input shape: [1, state size] -- one agent's state per iteration.
                    FloatTensor input = ctrl.floatTensorFactory.Create(_shape: new int[] { 1, states[idAgent.Key].Count },
                                                                       _data: states[idAgent.Key].ToArray());

                    IntTensor pred = policy.Sample(input);
                    actions.Add(idAgent.Key, new float[1] {
                        pred.Data[0]
                    });
                }
            }

            brain.SendActions(actions);
        }
    }
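For reference, the continuousPlayerActions and discretePlayerActions collections iterated above are expected to be simple serializable key bindings. Their declarations are not part of this example, so the following is only a minimal sketch inferred from the fields used in the loops (key, index, value), assuming UnityEngine's KeyCode type is in scope:

    [System.Serializable]
    public struct ContinuousPlayerAction
    {
        public KeyCode key;    // key that triggers this continuous input
        public int index;      // position in the action vector to write
        public float value;    // value written while the key is held
    }

    [System.Serializable]
    public struct DiscretePlayerAction
    {
        public KeyCode key;    // key that selects this discrete action
        public int value;      // discrete action id sent to the agents
    }

With bindings like these configured in the inspector, DecideAction translates held keys into either a shared continuous action vector or a single discrete action id per agent, as shown in the method above.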