/// <summary>
    /// Evaluates the policy function on a batch of observations and returns its action output together
    /// with its optional variance output. This is implemented for ISupervisedLearingModel so that this
    /// model can also be used for TrainerMimic.
    /// </summary>
    /// <param name="vectorObservation">Vector observations; the first dimension is the batch dimension.
    /// Required when <c>HasVectorObservation</c> is true.</param>
    /// <param name="visualObservation">Visual observations, one array per visual input. Required when
    /// <c>HasVisualObservation</c> is true.</param>
    /// <param name="actionsMask">Discrete action spaces only: masks fed to the policy function. When null,
    /// all-ones masks are created so nothing is masked. Ignored for continuous action spaces.</param>
    /// <returns>(mean, var). var is null when <c>SLHasVar</c> is false or when the second output is not a
    /// 2D float array (var will be null for discrete).</returns>
    public ValueTuple<float[,], float[,]> EvaluateAction(float[,] vectorObservation, List<float[,,,]> visualObservation, List<float[,]> actionsMask)
    {
        List<Array> inputLists = new List<Array>();

        if (HasVectorObservation)
        {
            Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
            inputLists.Add(vectorObservation);
        }
        if (HasVisualObservation)
        {
            Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
            inputLists.AddRange(visualObservation);
        }

        if (ActionSpace == SpaceType.discrete)
        {
            List<float[,]> masks = actionsMask;
            if (masks == null)
            {
                // No mask supplied: create all-ones masks sized to the batch.
                int batchSize = vectorObservation != null ? vectorObservation.GetLength(0) : visualObservation[0].GetLength(0);
                masks = CreateDummyMasks(ActionSizes, batchSize);
            }
            inputLists.AddRange(masks);
        }

        var result = PolicyFunction.Call(inputLists);

        float[,] actions = (float[,])result[0].eval();

        // Evaluate the variance output exactly once; anything that is not a 2D float array maps to null.
        float[,] outputVar = null;
        if (SLHasVar)
        {
            outputVar = result[1].eval() as float[,];
        }

        // Pass this batch's vector observations to the normalizer update on every evaluation.
        if (useInputNormalization && HasVectorObservation)
        {
            UpdateNormalizerFunction.Call(new List<Array>()
            {
                vectorObservation
            });
        }

        return ValueTuple.Create(actions, outputVar);
    }
// Example no. 2
// 0
    /// <summary>
    /// Query actions based on current states. The first dimension of every input array must be the batch dimension.
    /// </summary>
    /// <param name="vectorObservation">Current vector states. Can be batch input. Required when
    /// <c>HasVectorObservation</c> is true.</param>
    /// <param name="visualObservation">Current visual observations, one array per visual input. Required when
    /// <c>HasVisualObservation</c> is true.</param>
    /// <returns>The first output of <c>ActionFunction</c> for the given inputs.</returns>
    public virtual float[,] EvaluateAction(float[,] vectorObservation, List<float[,,,]> visualObservation)
    {
        List<Array> inputLists = new List<Array>();

        if (HasVectorObservation)
        {
            Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
            inputLists.Add(vectorObservation);
        }
        if (HasVisualObservation)
        {
            Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
            inputLists.AddRange(visualObservation);
        }

        var result = ActionFunction.Call(inputLists);

        // The evaluated output is returned as-is; no intermediate buffer is needed.
        float[,] actions = (float[,])result[0].eval();

        // Pass this batch's vector observations to the normalizer update on every evaluation.
        if (useInputNormalization && HasVectorObservation)
        {
            UpdateNormalizerFunction.Call(new List<Array>()
            {
                vectorObservation
            });
        }

        return actions;
    }
// Example no. 3
// 0
    /// <summary>
    /// Query actions based on current states. The first dimension of every input array must be the batch dimension.
    /// PPO mode only.
    /// </summary>
    /// <param name="vectorObservation">Current vector states. Can be batch input. Required when
    /// <c>HasVectorObservation</c> is true.</param>
    /// <param name="actionProbs">Output actions' probabilities. Note that it is the normalized log probability.
    /// For discrete action spaces it has one column per action branch, holding the value for the chosen action.</param>
    /// <param name="visualObservation">Current visual observations, one array per visual input. Required when
    /// <c>HasVisualObservation</c> is true.</param>
    /// <param name="actionsMask">Action masks for discrete action spaces. When null, all-ones masks are created.
    /// Ignored for continuous action spaces.</param>
    /// <returns>The first output of <c>ActionFunction</c>, or null if the action space is neither continuous
    /// nor discrete.</returns>
    public virtual float[,] EvaluateAction(float[,] vectorObservation, out float[,] actionProbs, List<float[,,,]> visualObservation, List<float[,]> actionsMask = null)
    {
        Debug.Assert(mode == Mode.PPO, "This method is for PPO mode only");

        List<Array> inputLists = new List<Array>();

        if (HasVectorObservation)
        {
            Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
            inputLists.Add(vectorObservation);
        }
        if (HasVisualObservation)
        {
            Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
            inputLists.AddRange(visualObservation);
        }

        float[,] actions = null;
        actionProbs = null;

        if (ActionSpace == SpaceType.continuous)
        {
            var result = ActionFunction.Call(inputLists);
            actions     = (float[,])result[0].eval();
            actionProbs = (float[,])result[1].eval();
        }
        else if (ActionSpace == SpaceType.discrete)
        {
            int batchSize  = vectorObservation != null ? vectorObservation.GetLength(0) : visualObservation[0].GetLength(0);
            int branchSize = ActionSizes.Length;

            // Create all-ones masks if the caller did not supply any.
            List<float[,]> masks = actionsMask ?? CreateDummyMasks(ActionSizes, batchSize);
            inputLists.AddRange(masks);

            var result = ActionFunction.Call(inputLists);
            actions = (float[,])result[0].eval();

            // Outputs 1..branchSize hold per-branch values; pick the entry for each chosen action.
            actionProbs = new float[batchSize, branchSize];
            for (int b = 0; b < branchSize; ++b)
            {
                var branchProbs = (float[,])result[b + 1].eval();
                for (int i = 0; i < batchSize; ++i)
                {
                    actionProbs[i, b] = branchProbs[i, Mathf.RoundToInt(actions[i, b])];
                }
            }
        }

        // Pass this batch's vector observations to the normalizer update on every evaluation.
        if (useInputNormalization && HasVectorObservation)
        {
            UpdateNormalizerFunction.Call(new List<Array>()
            {
                vectorObservation
            });
        }

        return actions;
    }
// Example no. 4
// 0
    /// <summary>
    /// Query actions based on current states. The first dimension of every input array must be the batch dimension.
    /// PPO mode only.
    /// </summary>
    /// <param name="vectorObservation">Current vector states. Can be batch input. Required when
    /// <c>HasVectorObservation</c> is true.</param>
    /// <param name="actionProbs">Output actions' probabilities. For continuous action spaces this is the second
    /// output of <c>ActionFunction</c>; otherwise a [batch, 1] array holding the output value at the chosen index.</param>
    /// <param name="visualObservation">Current visual observations, one array per visual input. Required when
    /// <c>HasVisualObservation</c> is true.</param>
    /// <param name="useProbability">Discrete action spaces only: when true, each action index is drawn with
    /// <c>MathUtils.IndexByChance</c> from the output row; otherwise the arg-max index is used.
    /// Has no effect for continuous action spaces.</param>
    /// <returns>For continuous action spaces, the first output of <c>ActionFunction</c>; otherwise a [batch, 1]
    /// array holding the chosen action index for each batch entry.</returns>
    public virtual float[,] EvaluateAction(float[,] vectorObservation, out float[,] actionProbs, List<float[,,,]> visualObservation, bool useProbability = true)
    {
        Debug.Assert(mode == Mode.PPO, "This method is for PPO mode only");

        List<Array> inputLists = new List<Array>();

        if (HasVectorObservation)
        {
            Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
            inputLists.Add(vectorObservation);
        }
        if (HasVisualObservation)
        {
            Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
            inputLists.AddRange(visualObservation);
        }

        var result = ActionFunction.Call(inputLists);

        var outputAction = (float[,])result[0].eval();

        float[,] actions;
        if (ActionSpace == SpaceType.continuous)
        {
            // The network output is the action itself; no pre-allocated buffers are needed.
            actions     = outputAction;
            actionProbs = (float[,])result[1].eval();
        }
        else
        {
            int batchSize = outputAction.GetLength(0);
            actions     = new float[batchSize, 1];
            actionProbs = new float[batchSize, 1];

            if (ActionSpace == SpaceType.discrete)
            {
                for (int j = 0; j < batchSize; ++j)
                {
                    // Fetch the row once per batch entry instead of once per use.
                    var row = outputAction.GetRow(j);
                    if (useProbability)
                    {
                        actions[j, 0] = MathUtils.IndexByChance(row);
                    }
                    else
                    {
                        actions[j, 0] = row.ArgMax();
                    }

                    actionProbs[j, 0] = outputAction[j, Mathf.RoundToInt(actions[j, 0])];
                }
            }
        }

        // Pass this batch's vector observations to the normalizer update on every evaluation.
        if (useInputNormalization && HasVectorObservation)
        {
            UpdateNormalizerFunction.Call(new List<Array>()
            {
                vectorObservation
            });
        }

        return actions;
    }