/// <summary>
/// Query actions based on the current states. The first dimension of each input array must be the batch dimension.
/// </summary>
/// <param name="vectorObservation">
/// Current vector states; may be a batch. Passed to the network only when <c>HasVectorObservation</c> is true.
/// </param>
/// <param name="visualObservation">
/// Visual observations. Appended to the network inputs only when <c>HasVisualObservation</c> is true.
/// </param>
/// <returns>The first output of <c>ActionFunction</c>, cast to a 2D float array.</returns>
public virtual float[,] EvaluateAction(float[,] vectorObservation, List<float[,,,]> visualObservation)
{
    // Inputs are ordered: vector observation first (if any), then every visual observation.
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    var result = ActionFunction.Call(inputLists);

    // Return the evaluated output directly. The previous version first allocated a
    // same-sized array and then immediately replaced it with this one.
    float[,] actions = (float[,])result[0].eval();

    // The normalizer update is fed the same vector observation that was just evaluated.
    if (useInputNormalization && HasVectorObservation)
    {
        UpdateNormalizerFunction.Call(new List<Array>() { vectorObservation });
    }

    return actions;
}
/// <summary>
/// Query the probabilities of the given actions based on the current states.
/// The first dimension of each input array must be the batch dimension.
/// </summary>
/// <param name="vectorObservation">
/// Current vector states; may be a batch. Passed to the network only when <c>HasVectorObservation</c> is true.
/// </param>
/// <param name="actions">
/// Actions whose probabilities are requested. For continuous action spaces the array is passed to
/// <c>ActionProbabilityFunction</c> as the last input; for discrete action spaces column 0 of each row is
/// rounded to an integer and used as an index into the corresponding row of the network output.
/// </param>
/// <param name="visualObservation">
/// Visual observations. Appended to the network inputs only when <c>HasVisualObservation</c> is true.
/// </param>
/// <returns>
/// Continuous: the first output of <c>ActionProbabilityFunction</c>.
/// Discrete: a <c>[batch, 1]</c> array holding, per row, the network output at the given action index.
/// Any other action space: a zero-filled <c>[batch, 1]</c> array.
/// </returns>
public virtual float[,] EvaluateProbability(float[,] vectorObservation, float[,] actions, List<float[,,,]> visualObservation)
{
    Debug.Assert(mode == Mode.PPO, "This method is for PPO mode only");
    Debug.Assert(TrainingEnabled == true, "The model needs to initalized with Training enabled to use EvaluateProbability()");

    // Inputs are ordered: vector observation first (if any), then every visual observation.
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    // Read the batch size before branching so that a null 'actions' array still fails here,
    // before any network call, exactly as it did previously.
    int batchSize = actions.GetLength(0);

    if (ActionSpace == SpaceType.continuous)
    {
        // The network output is returned as-is, so no result buffer is allocated on this path.
        inputLists.Add(actions);
        var result = ActionProbabilityFunction.Call(inputLists);
        return (float[,])result[0].eval();
    }

    // Non-continuous spaces report a single value per batch row.
    var actionProbs = new float[batchSize, 1];
    if (ActionSpace == SpaceType.discrete)
    {
        var result = ActionFunction.Call(inputLists);
        var outputAction = (float[,])result[0].eval();
        for (int j = 0; j < outputAction.GetLength(0); ++j)
        {
            // The action is stored as a float; round it back to the integer index it represents.
            actionProbs[j, 0] = outputAction.GetRow(j)[Mathf.RoundToInt(actions[j, 0])];
        }
    }

    return actionProbs;
}
/// <summary>
/// Query actions, together with their probabilities, based on the current states.
/// The first dimension of each input array must be the batch dimension.
/// </summary>
/// <param name="vectorObservation">
/// Current vector states; may be a batch. Passed to the network only when <c>HasVectorObservation</c> is true.
/// </param>
/// <param name="actionProbs">
/// Output probabilities of the returned actions. Continuous: the second output of <c>ActionFunction</c>.
/// Discrete: a <c>[batch, 1]</c> array holding, per row, the network output at the chosen action index.
/// Any other action space: a zero-filled <c>[batch, 1]</c> array.
/// </param>
/// <param name="visualObservation">
/// Visual observations. Appended to the network inputs only when <c>HasVisualObservation</c> is true.
/// </param>
/// <param name="useProbability">
/// Only consulted for discrete action spaces: when true the action index for each row is picked with
/// <c>MathUtils.IndexByChance</c>, otherwise the row's <c>ArgMax</c> is used.
/// It has no effect for continuous action spaces.
/// </param>
/// <returns>
/// Continuous: the first output of <c>ActionFunction</c>.
/// Discrete: a <c>[batch, 1]</c> array holding the chosen action index of each row.
/// Any other action space: a zero-filled <c>[batch, 1]</c> array.
/// </returns>
public virtual float[,] EvaluateAction(float[,] vectorObservation, out float[,] actionProbs, List<float[,,,]> visualObservation, bool useProbability = true)
{
    Debug.Assert(mode == Mode.PPO, "This method is for PPO mode only");

    // Inputs are ordered: vector observation first (if any), then every visual observation.
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    var result = ActionFunction.Call(inputLists);
    var outputAction = (float[,])result[0].eval();

    float[,] actions;
    if (ActionSpace == SpaceType.continuous)
    {
        // Both results come straight from the network, so nothing is allocated on this path.
        actions = outputAction;
        actionProbs = (float[,])result[1].eval();
    }
    else
    {
        // Non-continuous spaces report a single value per batch row.
        int batchSize = outputAction.GetLength(0);
        actions = new float[batchSize, 1];
        actionProbs = new float[batchSize, 1];

        if (ActionSpace == SpaceType.discrete)
        {
            for (int j = 0; j < batchSize; ++j)
            {
                if (useProbability)
                {
                    actions[j, 0] = MathUtils.IndexByChance(outputAction.GetRow(j));
                }
                else
                {
                    actions[j, 0] = outputAction.GetRow(j).ArgMax();
                }

                // The chosen index was stored as a float; round it back to look up its probability.
                actionProbs[j, 0] = outputAction.GetRow(j)[Mathf.RoundToInt(actions[j, 0])];
            }
        }
    }

    // The normalizer update is fed the same vector observation that was just evaluated.
    if (useInputNormalization && HasVectorObservation)
    {
        UpdateNormalizerFunction.Call(new List<Array>() { vectorObservation });
    }

    return actions;
}