public QNetworkSimple(int stateSize, int actionSize, int numLayers, int hiddenSize, DeviceDescriptor device, float initialWeightScale = 0.01f)
{
    Device = device;
    StateSize = stateSize;
    ActionSize = actionSize;

    //create actor network part
    var inputA = new InputLayerDense(stateSize);
    var outputA = new OutputLayerDense(hiddenSize, null, OutputLayerDense.LossFunction.None);
    outputA.HasBias = false;
    outputA.InitialWeightScale = initialWeightScale;
    SequentialNetworkDense qNetwork = new SequentialNetworkDense(inputA, LayerDefineHelper.DenseLayers(numLayers, hiddenSize, false, NormalizationMethod.None, 0, initialWeightScale, new ReluDef()), outputA, device);

    //separate the advantage and value streams (dueling architecture), which is reported to work better
    var midStream = outputA.GetOutputVariable();
    var advantageStream = CNTKLib.Slice(midStream, AxisVector.Repeat(new Axis(0), 1), IntVector.Repeat(0, 1), IntVector.Repeat(hiddenSize / 2, 1));
    var valueStream = CNTKLib.Slice(midStream, AxisVector.Repeat(new Axis(0), 1), IntVector.Repeat(hiddenSize / 2, 1), IntVector.Repeat(hiddenSize, 1));
    var adv = Layers.Dense(advantageStream, actionSize, device, false, "QNetworkAdvantage", initialWeightScale);
    var value = Layers.Dense(valueStream, 1, device, false, "QNetworkValue", initialWeightScale);

    InputState = inputA.InputVariable;
    //OutputQs = outputA.GetOutputVariable();
    OutputQs = value.Output + CNTKLib.Minus(adv, CNTKLib.ReduceMean(adv, Axis.AllStaticAxes())).Output;
}
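The final assignment is the standard dueling aggregation, Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a)). In isolation, that step looks like the sketch below; here `value` and `adv` are assumed to be plain CNTK Variables of shape [1] and [actionSize] (unlike the Dense-layer outputs above), and only stock CNTKLib ops are used:

//sketch only: dueling aggregation of a value stream and an advantage stream
var meanAdv = CNTKLib.ReduceMean(adv, Axis.AllStaticAxes());  //mean advantage over all actions
var centeredAdv = CNTKLib.Minus(adv, meanAdv);                //A(s,a) - mean_a A(s,a)
var q = CNTKLib.Plus(value, centeredAdv);                     //Q(s,a) = V(s) + centered advantage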
public static Function LayerNormalization(Variable input, DeviceDescriptor device, float initB = 0, float initScale = 1, string name = "", float eps = 0.00001f)
{
    //get the mean first
    //var mean = CNTKLib.ReduceMean(input, Axis.AllStaticAxes());
    //var centered = CNTKLib.Minus(input, mean);
    //var squared = CNTKLib.Square(centered);
    //var variance = CNTKLib.ReduceMean(squared, Axis.AllStaticAxes());

    var biasParams = new Parameter(new int[] { 1 }, initB, device, name + BiasSuffix);
    var scaleParams = new Parameter(new int[] { 1 }, initScale, device, name + WeightSuffix);
    //var epsConst = Constant.Scalar(DataType.Float, 0.00001f);
    //var std = CNTKLib.Sqrt(variance + epsConst);

    var normalized = Normalize(input, Axis.AllStaticAxes(), eps);
    var result = normalized * scaleParams + biasParams;
    return result;
}
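A minimal call-site sketch, assuming this method lives on the project's Layers helper class (as Layers.Dense and Layers.HuberLoss do elsewhere in these snippets) and that a 128-dimensional float input is available:

//sketch only: apply layer normalization to an input feature vector
var device = DeviceDescriptor.CPUDevice;
var features = CNTKLib.InputVariable(new int[] { 128 }, DataType.Float);
var normalizedFeatures = Layers.LayerNormalization(features, device, 0f, 1f, "ln1");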
//public Variable testOutputProb;

public DQLModel(QNetwork network)
{
    Network = network;

    InputOldAction = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);
    InputTargetQ = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);

    var oneHotOldAction = CNTKLib.OneHotOp(InputOldAction, (uint)ActionSize, false, new Axis(0));
    outputTargetQ = CNTKLib.ReduceSum(CNTKLib.ElementTimes(OutputQs, oneHotOldAction), Axis.AllStaticAxes());

    //OutputLoss = CNTKLib.Square(CNTKLib.Minus(outputTargetQ, InputTargetQ), "Loss");
    OutputLoss = Layers.HuberLoss(outputTargetQ, InputTargetQ, Device);

    OutputAction = CNTKLib.Argmax(OutputQs, new Axis(0));
    OutputMaxQ = CNTKLib.ReduceMax(OutputQs, new Axis(0));

    CNTKFunction = Function.Combine(new List<Variable>() { OutputLoss, OutputAction, OutputMaxQ });
}
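Layers.HuberLoss is a project-specific helper whose definition is not shown here. A rough equivalent built only from standard CNTKLib ops (an assumption about the behaviour with delta = 1, not the project's actual implementation) could look like this:

//sketch only: Huber-style loss assembled from stock CNTK ops
static Function HuberLossSketch(Variable prediction, Variable target, DeviceDescriptor device)
{
    var error = CNTKLib.Minus(prediction, target);
    var absError = CNTKLib.Abs(error);
    var one = Constant.Scalar<float>(1f, device);
    var half = Constant.Scalar<float>(0.5f, device);
    var quadratic = CNTKLib.ElementTimes(half, CNTKLib.Square(error)); //0.5 * e^2, used when |e| <= 1
    var linear = CNTKLib.Minus(absError, half);                        //|e| - 0.5, used when |e| > 1
    var useQuadratic = CNTKLib.LessEqual(absError, one);               //elementwise 0/1 mask
    return CNTKLib.Plus(
        CNTKLib.ElementTimes(useQuadratic, quadratic),
        CNTKLib.ElementTimes(CNTKLib.Minus(one, useQuadratic), linear));
}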
static Function ReduceMeanAll(Function errors)
{
    var allAxes = Axis.AllStaticAxes();
    return CNTKLib.ReduceMean(errors, allAxes);
}
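A brief usage sketch; the two InputVariables and the squared-error construction are hypothetical, just to show the reduction from a vector of per-element errors to a single scalar:

//sketch only: collapse per-element errors to a scalar mean over all static axes
var target = CNTKLib.InputVariable(new int[] { 3 }, DataType.Float);
var prediction = CNTKLib.InputVariable(new int[] { 3 }, DataType.Float);
var errors = CNTKLib.Square(CNTKLib.Minus(prediction, target)); //per-element squared error
var meanError = ReduceMeanAll(errors);                          //scalar mean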
protected override void EndProcessing()
{
    var axis = Axis.AllStaticAxes();
    WriteObject(axis);
}
//public Variable testOutputProb;

public PPOModel(PPONetwork network)
{
    Network = network;

    //inputs
    if (IsActionContinuous)
    {
        InputAction = CNTKLib.InputVariable(new int[] { Network.ActionSize }, DataType.Float);
        InputOldProb = CNTKLib.InputVariable(new int[] { Network.ActionSize }, DataType.Float);
    }
    else
    {
        InputAction = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);
        InputOldProb = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);
    }
    InputAdvantage = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);
    InputTargetValue = CNTKLib.InputVariable(new int[] { 1 }, DataType.Float);
    InputClipEpsilon = Constant.Scalar<float>(0.1f, Device);
    InputValuelossWeight = Constant.Scalar<float>(1f, Device);
    InputEntropyLossWeight = Constant.Scalar<float>(0f, Device);

    Variable actionProb = null;
    if (IsActionContinuous)
    {
        //entropy term for the Gaussian (continuous-action) policy
        var temp = CNTKLib.ElementTimes(Constant.Scalar(DataType.Float, 2 * Mathf.PI * 2.7182818285), Network.OutputVariance);
        temp = CNTKLib.ElementTimes(Constant.Scalar(DataType.Float, 0.5), temp);
        OutputEntropy = CNTKLib.ReduceSum(temp, Axis.AllStaticAxes());

        //probability of the taken action under the current policy
        actionProb = Layers.NormalProbability(InputAction, Network.OutputMean, Network.OutputVariance, Device);
    }
    else
    {
        OutputEntropy = CNTKLib.Minus(Constant.Scalar<float>(0, Device), CNTKLib.ReduceSum(
            CNTKLib.ElementTimes(
                Network.OutputProbabilities,
                CNTKLib.Log(Network.OutputProbabilities + Constant.Scalar<float>(0.000000001f, Device))),
            Axis.AllStaticAxes()));

        var oneHot = CNTKLib.OneHotOp(InputAction, (uint)Network.ActionSize, false, new Axis(0));
        actionProb = CNTKLib.ReduceSum(CNTKLib.ElementTimes(Network.OutputProbabilities, oneHot), Axis.AllStaticAxes());
    }
    //testOutputProb = actionProb;

    //value loss: simple squared error
    OutputValueLoss = CNTKLib.SquaredError(Network.OutputValue, InputTargetValue);

    //policy loss
    //1. clipped surrogate loss
    var probRatio = CNTKLib.ElementDivide(actionProb, InputOldProb + Constant.Scalar<float>(0.0000000001f, Device));
    var p_opt_a = CNTKLib.ElementTimes(probRatio, InputAdvantage);
    var p_opt_b = CNTKLib.ElementTimes(
        CNTKLib.Clip(probRatio,
            CNTKLib.Minus(Constant.Scalar<float>(1, Device), InputClipEpsilon),
            Constant.Scalar<float>(1, Device) + InputClipEpsilon),
        InputAdvantage);

    OutputPolicyLoss = CNTKLib.Minus(Constant.Scalar<float>(1, Device), CNTKLib.ReduceMean(CNTKLib.ElementMin(p_opt_a, p_opt_b, "min"), Axis.AllStaticAxes()));
    //OutputPolicyLoss = CNTKLib.ReduceMean(CNTKLib.ElementMin(p_opt_a, p_opt_b, "min"), Axis.AllStaticAxes());
    //OutputPolicyLoss = CNTKLib.Minus(Constant.Scalar<float>(1, Device), CNTKLib.ReduceMean(p_opt_a, Axis.AllStaticAxes()));

    //final weighted loss
    OutputLoss = OutputPolicyLoss + CNTKLib.ElementTimes(InputValuelossWeight, OutputValueLoss);
    OutputLoss = CNTKLib.Minus(OutputLoss, CNTKLib.ElementTimes(InputEntropyLossWeight, OutputEntropy));
}
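For reference, the differential entropy of a diagonal Gaussian is 0.5 * ln(2 * pi * e * sigma^2) per dimension, whereas the continuous-action branch above sums 0.5 * (2 * pi * e * sigma^2) without the logarithm. A log-based variant would look like the following sketch (not the project's code; Network.OutputVariance and Mathf are reused from the snippet above):

//sketch only: Gaussian entropy with the log term, summed over action dimensions
var twoPiE = Constant.Scalar(DataType.Float, 2 * Mathf.PI * 2.7182818285);
var entropyPerDim = CNTKLib.ElementTimes(
    Constant.Scalar(DataType.Float, 0.5),
    CNTKLib.Log(CNTKLib.ElementTimes(twoPiE, Network.OutputVariance)));
var entropy = CNTKLib.ReduceSum(entropyPerDim, Axis.AllStaticAxes());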