/// <summary>
/// Implemented for ISupervisedLearingModel so that this model can also be used by TrainerMimic.
/// Runs the policy on a batch of observations and returns its action output together with
/// the variance output when the model has one.
/// </summary>
/// <param name="vectorObservation">Vector observations; the first dimension is the batch. Required when HasVectorObservation is true.</param>
/// <param name="visualObservation">Visual observations, one array per visual input. Required when HasVisualObservation is true.</param>
/// <param name="actionsMask">Action masks for a discrete action space, one entry per branch. When null, an all-ones mask is created. Not used for non-discrete action spaces.</param>
/// <returns>(mean, var). var is null when SLHasVar is false or when the variance output does not evaluate to a float[,] (the original notes this is the case for discrete action spaces).</returns>
public ValueTuple<float[,], float[,]> EvaluateAction(float[,] vectorObservation, List<float[,,,]> visualObservation, List<float[,]> actionsMask)
{
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    if (ActionSpace == SpaceType.discrete)
    {
        // Batch size comes from whichever observation input is present.
        int batchSize = vectorObservation != null ? vectorObservation.GetLength(0) : visualObservation[0].GetLength(0);
        List<float[,]> masks = actionsMask;
        // Create an all-ones mask if the caller did not supply one.
        if (masks == null)
        {
            masks = CreateDummyMasks(ActionSizes, batchSize);
        }
        inputLists.AddRange(masks);
    }

    var result = PolicyFunction.Call(inputLists);
    float[,] actions = (float[,])result[0].eval();

    float[,] outputVar = null;
    if (SLHasVar)
    {
        // Evaluate the variance output only once; if it is not a float[,], var stays null.
        outputVar = result[1].eval() as float[,];
    }

    // Feed the vector observations to the input normalizer update on every call.
    if (useInputNormalization && HasVectorObservation)
    {
        UpdateNormalizerFunction.Call(new List<Array>() { vectorObservation });
    }
    return ValueTuple.Create(actions, outputVar);
}
/// <summary>
/// Query the probabilities of the given actions based on the current states. The first dimension
/// of every array must be the batch dimension. Note that the result is the normalized log probability.
/// </summary>
/// <param name="vectorObservation">Current vector observations (batch first). Required when HasVectorObservation is true.</param>
/// <param name="actions">Actions to evaluate (batch first). For a discrete action space, column b holds the chosen index for branch b.</param>
/// <param name="visualObservation">Visual observations, one array per visual input. Required when HasVisualObservation is true.</param>
/// <param name="actionsMask">Action masks for a discrete action space, one entry per branch. When null, an all-ones mask is created.</param>
/// <returns>
/// Continuous: the output of ActionProbabilityFunction. Discrete: a [batch, branchCount] array
/// holding, for each branch, the value the policy output assigns to the chosen action index.
/// </returns>
public virtual float[,] EvaluateProbability(float[,] vectorObservation, float[,] actions, List<float[,,,]> visualObservation, List<float[,]> actionsMask = null)
{
    Debug.Assert(TrainingEnabled == true, "The model needs to initalized with Training enabled to use EvaluateProbability()");
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    float[,] actionProbs;
    if (ActionSpace == SpaceType.continuous)
    {
        inputLists.Add(actions);
        var result = ActionProbabilityFunction.Call(inputLists);
        actionProbs = (float[,])result[0].eval();
    }
    else if (ActionSpace == SpaceType.discrete)
    {
        // Take the batch size from whichever observation input is present, so visual-only models
        // (vectorObservation == null) work here just as they do in EvaluateAction.
        int batchSize = vectorObservation != null ? vectorObservation.GetLength(0) : visualObservation[0].GetLength(0);
        int branchSize = ActionSizes.Length;
        List<float[,]> masks = actionsMask;
        // Create an all-ones mask if the caller did not supply one.
        if (masks == null)
        {
            masks = CreateDummyMasks(ActionSizes, batchSize);
        }
        inputLists.AddRange(masks);
        var result = PolicyFunction.Call(inputLists);

        // result[0] is the action output; result[b + 1] holds the per-action log probabilities of branch b.
        actionProbs = new float[batchSize, branchSize];
        for (int b = 0; b < branchSize; ++b)
        {
            var branchProbs = (float[,])result[b + 1].eval();
            for (int i = 0; i < batchSize; ++i)
            {
                // Discrete actions are stored as floats; round to recover the chosen index.
                actionProbs[i, b] = branchProbs[i, Mathf.RoundToInt(actions[i, b])];
            }
        }
    }
    else
    {
        // Fallback for any other SpaceType value: a zero-filled [batch, 1] array, as before.
        actionProbs = new float[actions.GetLength(0), 1];
    }
    return actionProbs;
}
/// <summary>
/// Query actions based on the current states. The first dimension of every array must be the batch dimension.
/// </summary>
/// <param name="vectorObservation">Current vector observations. Can be batch input. Required when HasVectorObservation is true.</param>
/// <param name="actionProbs">Output: the probabilities of the returned actions. Note that these are normalized log probabilities.</param>
/// <param name="visualObservation">Visual observations, one array per visual input. Required when HasVisualObservation is true.</param>
/// <param name="actionsMask">Action masks for a discrete action space, one entry per branch. When null, an all-ones mask is created.</param>
/// <returns>The actions produced by the policy (batch first).</returns>
public virtual float[,] EvaluateAction(float[,] vectorObservation, out float[,] actionProbs, List<float[,,,]> visualObservation, List<float[,]> actionsMask = null)
{
    List<Array> inputLists = new List<Array>();
    if (HasVectorObservation)
    {
        Debug.Assert(vectorObservation != null, "Must Have vector observation inputs!");
        inputLists.Add(vectorObservation);
    }
    if (HasVisualObservation)
    {
        Debug.Assert(visualObservation != null, "Must Have visual observation inputs!");
        inputLists.AddRange(visualObservation);
    }

    float[,] actions = null;
    actionProbs = null;
    if (ActionSpace == SpaceType.continuous)
    {
        var result = PolicyFunction.Call(inputLists);
        actions = (float[,])result[0].eval();
        actionProbs = (float[,])result[1].eval();
    }
    else if (ActionSpace == SpaceType.discrete)
    {
        // Batch size comes from whichever observation input is present.
        int batchSize = vectorObservation != null ? vectorObservation.GetLength(0) : visualObservation[0].GetLength(0);
        int branchSize = ActionSizes.Length;
        List<float[,]> masks = actionsMask;
        // Create an all-ones mask if the caller did not supply one.
        if (masks == null)
        {
            masks = CreateDummyMasks(ActionSizes, batchSize);
        }
        inputLists.AddRange(masks);
        var result = PolicyFunction.Call(inputLists);
        actions = (float[,])result[0].eval();

        // result[b + 1] holds the per-action log probabilities of branch b; pick the chosen action's entry.
        actionProbs = new float[batchSize, branchSize];
        for (int b = 0; b < branchSize; ++b)
        {
            var branchProbs = (float[,])result[b + 1].eval();
            for (int i = 0; i < batchSize; ++i)
            {
                // Discrete actions are stored as floats; round to recover the chosen index.
                actionProbs[i, b] = branchProbs[i, Mathf.RoundToInt(actions[i, b])];
            }
        }
    }

    // Feed the vector observations to the input normalizer update on every call.
    if (useInputNormalization && HasVectorObservation)
    {
        UpdateNormalizerFunction.Call(new List<Array>() { vectorObservation });
    }
    return actions;
}