Example #1
    /// <summary>
    /// Create the visual input tensors for the given BrainParameters of an ML-Agents brain.
    /// </summary>
    /// <param name="brainParameters">the brain parameters that define the camera resolutions</param>
    /// <returns>List of all input visual tensors</returns>
    protected static List <Tensor> CreateVisualInputs(BrainParameters brainParameters)
    {
        if (brainParameters.cameraResolutions == null || brainParameters.cameraResolutions.Length == 0)
        {
            return(null);
        }
        List <Tensor> allInputs = new List <Tensor>();
        int           i         = 0;

        foreach (var r in brainParameters.cameraResolutions)
        {
            int width  = r.width;
            int height = r.height;
            int channels;
            if (r.blackAndWhite)
            {
                channels = 1;
            }
            else
            {
                channels = 3;
            }

            var input = UnityTFUtils.Input(new int?[] { height, width, channels }, name: "InputVisual" + i)[0];
            allInputs.Add(input);

            i++;
        }

        return(allInputs);
    }
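
Note: each camera becomes one [height, width, channels] input, with one channel for black-and-white cameras and three for RGB. A minimal plain-C# sketch of that shape logic, with no Unity or TensorFlow dependencies (the Resolution struct here is a stand-in for the ML-Agents camera resolution type, not the real one):

    public readonly struct Resolution
    {
        public readonly int  Width;
        public readonly int  Height;
        public readonly bool BlackAndWhite;
        public Resolution(int width, int height, bool blackAndWhite)
        {
            Width = width; Height = height; BlackAndWhite = blackAndWhite;
        }
    }

    public static class VisualShapeSketch
    {
        //mirrors the channel selection in CreateVisualInputs above
        public static int[][] VisualInputShapes(Resolution[] resolutions)
        {
            var shapes = new int[resolutions.Length][];
            for (int i = 0; i < resolutions.Length; ++i)
            {
                int channels = resolutions[i].BlackAndWhite ? 1 : 3;
                shapes[i] = new[] { resolutions[i].Height, resolutions[i].Width, channels };
            }
            return shapes;
        }
    }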
Example #2
    /// <summary>
    /// Initialize the model for supervised learning
    /// </summary>
    /// <param name="trainerParams"></param>
    /// <param name="vectorObsTensor"></param>
    /// <param name="inputVisualTensors"></param>
    /// <param name="outputActionFromNetwork"></param>
    /// <param name="outputLogVarianceFromNetwork"></param>
    /// <param name="weightsToUpdate"></param>
    protected void InitializeSLStructures(TrainerParams trainerParams, Tensor vectorObsTensor, Tensor normalizedVectorObs, List <Tensor> inputVisualTensors)
    {
        Tensor outActionMean, outActionLogVariance, outValue;

        network.BuildNetworkForContinuousActionSapce(normalizedVectorObs, inputVisualTensors, null, null, ActionSizes[0], out outActionMean, out outValue, out outActionLogVariance);

        List <Tensor> allobservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allobservationInputs.Add(vectorObsTensor);
        }
        if (HasVisualObservation)
        {
            allobservationInputs.AddRange(inputVisualTensors);
        }

        Tensor outputVariance = null;

        //if (ActionSpace == SpaceType.continuous)
        //{
        outputVariance = K.exp(outActionLogVariance);
        ActionFunction = K.function(allobservationInputs, new List <Tensor> {
            outActionMean, outputVariance
        }, null, "ActionFunction");

        /*}
         * else
         * {
         *
         *  ActionFunction = K.function(allobservationInputs, new List<Tensor> { outputActionFromNetwork }, null, "ActionFunction");
         * }*/



        //create the loss for the supervised learning part
        Tensor supervisedLearningLoss = null;
        var    inputActionLabel      = UnityTFUtils.Input(new int?[] { ActionSpace == SpaceType.continuous ? ActionSizes[0] : 1 }, name: "InputAction", dtype: ActionSpace == SpaceType.continuous ? DataType.Float : DataType.Int32)[0];

        /*if (ActionSpace == SpaceType.discrete)
         * {
         *  var onehotInputAction = K.one_hot(inputActionLabel, K.constant<int>(ActionSizes[0], dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f));
         *  onehotInputAction = K.reshape(onehotInputAction, new int[] { -1, ActionSizes[0] });
         *  supervisedLearningLoss = K.mean(K.categorical_crossentropy(onehotInputAction, outputActionFromNetwork, false));
         * }
         * else
         * {*/
        supervisedLearningLoss = K.mean(K.mean(0.5 * K.square(inputActionLabel - outActionMean) / outputVariance + 0.5 * outActionLogVariance));
        //}

        var updates  = AddOptimizer(network.GetActorWeights(), supervisedLearningLoss, optimizer);
        var slInputs = new List <Tensor>();

        slInputs.AddRange(allobservationInputs); slInputs.Add(inputActionLabel);
        UpdateSLFunction = K.function(slInputs, new List <Tensor>()
        {
            supervisedLearningLoss
        }, updates, "UpdateSLFunction");
    }
    /// <summary>
    /// Initialize the model for supervised learning
    /// </summary>
    /// <param name="trainerParams"></param>
    /// <param name="stateTensor"></param>
    /// <param name="inputVisualTensors"></param>
    /// <param name="outputActionFromNetwork"></param>
    /// <param name="outputVarianceFromNetwork"></param>
    /// <param name="weightsToUpdate"></param>
    protected void InitializeSLStructures(TrainerParams trainerParams, Tensor stateTensor, List <Tensor> inputVisualTensors, Tensor outputActionFromNetwork, Tensor outputVarianceFromNetwork, List <Tensor> weightsToUpdate)
    {
        List <Tensor> allobservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allobservationInputs.Add(stateTensor);
        }
        if (HasVisualObservation)
        {
            allobservationInputs.AddRange(inputVisualTensors);
        }

        if (ActionSpace == SpaceType.continuous)
        {
            ActionFunction = K.function(allobservationInputs, new List <Tensor> {
                outputActionFromNetwork, outputVarianceFromNetwork
            }, null, "ActionFunction");
        }
        else
        {
            ActionFunction = K.function(allobservationInputs, new List <Tensor> {
                outputActionFromNetwork
            }, null, "ActionFunction");
        }



        //create the loss for the supervised learning part
        Tensor supervisedLearningLoss = null;
        var    inputActionLabel      = UnityTFUtils.Input(new int?[] { ActionSpace == SpaceType.continuous ? ActionSize : 1 }, name: "InputAction", dtype: ActionSpace == SpaceType.continuous ? DataType.Float : DataType.Int32)[0];

        if (ActionSpace == SpaceType.discrete)
        {
            var onehotInputAction = K.one_hot(inputActionLabel, K.constant <int>(ActionSize, dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f));
            onehotInputAction     = K.reshape(onehotInputAction, new int[] { -1, ActionSize });
            supervisedLearningLoss = K.mean(K.categorical_crossentropy(onehotInputAction, outputActionFromNetwork, false));
        }
        else
        {
            supervisedLearningLoss = K.mean(K.mean(0.5 * K.square(inputActionLabel - outputActionFromNetwork) / outputVarianceFromNetwork + 0.5 * K.log(outputVarianceFromNetwork)));
        }

        var updates  = AddOptimizer(weightsToUpdate, supervisedLearningLoss, optimizer);
        var slInputs = new List <Tensor>();

        slInputs.AddRange(allobservationInputs); slInputs.Add(inputActionLabel);
        UpdateSLFunction = K.function(slInputs, new List <Tensor>()
        {
            supervisedLearningLoss
        }, updates, "UpdateSLFunction");
    }
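
Note: both overloads use the same continuous-action loss, mean(0.5 * (label - mean)^2 / variance + 0.5 * log(variance)), which is the negative log-likelihood of the label under a Gaussian N(mean, variance) up to the constant 0.5 * log(2 * pi). A minimal plain-C# sketch of the per-sample term (scalar actions assumed; the tensor version above also averages over batch and action dimensions):

    using System;

    public static class GaussianNllSketch
    {
        //per-sample supervised loss for a continuous action with predicted mean and variance
        public static double GaussianNllLoss(double actionLabel, double mean, double variance)
        {
            double diff = actionLabel - mean;
            return 0.5 * diff * diff / variance + 0.5 * Math.Log(variance);
        }
    }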
Example #4
    /// <summary>
    /// Trainers will call this method to initialize the model. This method will call the InitializeInner()
    /// </summary>
    /// <param name="brainParameters">brain parameter of the MLagent brain</param>
    /// <param name="enableTraining">whether enable training</param>
    /// <param name="trainerParams">trainer parameters passed by the trainer. Training will not be enabled </param>
    public virtual void Initialize(BrainParameters brainParameters, bool enableTraining, TrainerParams trainerParams = null)
    {
        Debug.Assert(Initialized == false, "Model already initialized");

        NameScope ns = null;

        if (!string.IsNullOrEmpty(modelName))
        {
            ns = Current.K.name_scope(modelName);
        }

        ActionSizes = brainParameters.vectorActionSize;
        StateSize   = brainParameters.vectorObservationSize * brainParameters.numStackedVectorObservations;
        ActionSpace = brainParameters.vectorActionSpaceType;

        Debug.Assert(ActionSizes[0] > 0, "Action size cannot be zero");

        //create basic inputs
        var inputStateTensor = StateSize > 0 ? UnityTFUtils.Input(new int?[] { StateSize }, name: "InputStates")[0] : null;

        HasVectorObservation = inputStateTensor != null;
        var inputVisualTensors = CreateVisualInputs(brainParameters);

        HasVisualObservation = inputVisualTensors != null;

        //run the inner initialization
        InitializeInner(brainParameters, inputStateTensor, inputVisualTensors, enableTraining ? trainerParams : null);

        //test
        //Debug.LogWarning("Tensorflow Graph is saved for test purpose at: SavedGraph/" + name + ".pb");
        //((UnityTFBackend)Current.K).ExportGraphDef("SavedGraph/" + name + ".pb");

        Current.K.try_initialize_variables(true);

        if (ns != null)
        {
            ns.Dispose();
        }

        if (checkpointToLoad != null)
        {
            RestoreCheckpoint(checkpointToLoad.bytes, true);
        }
        Initialized     = true;
        TrainingEnabled = enableTraining;
    }
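
Note: a hypothetical call site, for orientation only (the model and brain variable names are illustrative, not from the source):

        //inference only: trainer params are dropped because enableTraining is false
        //model.Initialize(brain.brainParameters, enableTraining: false, trainerParams);

        //training: trainer params are forwarded to InitializeInner
        //model.Initialize(brain.brainParameters, enableTraining: true, trainerParams);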
Example #5
    public void TestConv2D()
    {
        var inputLayer = UnityTFUtils.Input(shape: new int?[] { 32, 32, 3 });

        var conv1 = new Conv2D(16, new int[] { 3, 3 }, padding: PaddingType.Same, activation: new ReLU());
        var conv2 = new Conv2D(3, new int[] { 3, 3 }, padding: PaddingType.Same, activation: new ReLU());

        var target = UnityTFUtils.Input(shape: new int?[] { 32, 32, 3 });


        var pred  = conv2.Call(conv1.Call(inputLayer[0])[0])[0];
        var lossM = new MeanSquareError();

        lossM.Call(target[0], pred);


        ((UnityTFBackend)K).ExportGraphDef("SavedGraph/convLayer.pb");
    }
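
Note: with Same padding and an assumed default stride of 1, each 3x3 convolution above preserves the 32x32 spatial size, so the prediction and the target have matching shapes for the mean-square error. A plain-C# sketch of the usual output-size rule (the common Same/Valid conventions, not taken from this library):

    public static class ConvShapeSketch
    {
        public static int ConvOutputSize(int inputSize, int kernelSize, int stride, bool samePadding)
        {
            return samePadding
                   ? (inputSize + stride - 1) / stride         //ceil(inputSize / stride); kernel size does not matter
                   : (inputSize - kernelSize) / stride + 1;    //valid padding
        }
    }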
Example #6
    public void TestLayer()
    {
        var inputLayer = UnityTFUtils.Input(shape: new int?[] { 3 });

        var dense1 = new Dense(10, new ReLU(), true);
        var dense2 = new Dense(1, new ReLU(), true);

        var target = UnityTFUtils.Input(shape: new int?[] { 1 });

        var o = dense1.Call(inputLayer[0]);

        o = dense2.Call(o[0]);

        var lossM = new MeanSquareError();

        lossM.Call(target[0], o[0]);



        ((UnityTFBackend)K).ExportGraphDef("SavedGraph/testLayer.pb");
    }
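
Note: the MeanSquareError layer used in both tests reduces to the familiar elementwise formula. A plain-C# sketch over scalar arrays instead of tensors:

    public static class MseSketch
    {
        public static double MeanSquareError(double[] prediction, double[] target)
        {
            double sum = 0.0;
            for (int i = 0; i < prediction.Length; ++i)
            {
                double diff = prediction[i] - target[i];
                sum += diff * diff;
            }
            return sum / prediction.Length;
        }
    }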
    public override void InitializeInner(BrainParameters brainParameters, Tensor inputStateTensor, List <Tensor> inputVisualTensors, TrainerParams trainerParams)
    {
        //build the network
        var    networkOutputs = network.BuildNetwork(inputStateTensor, inputVisualTensors, null, ActionSize, ActionSpace);
        Tensor outputAction   = networkOutputs.Item1;
        Tensor outputVar      = networkOutputs.Item2;

        hasVariance = outputVar != null && brainParameters.vectorActionSpaceType == SpaceType.continuous;

        List <Tensor> observationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            observationInputs.Add(inputStateTensor);
        }
        if (HasVisualObservation)
        {
            observationInputs.AddRange(inputVisualTensors);
        }
        if (hasVariance)
        {
            ActionFunction = K.function(observationInputs, new List <Tensor> {
                outputAction, outputVar
            }, null, "ActionFunction");
        }
        else
        {
            ActionFunction = K.function(observationInputs, new List <Tensor> {
                outputAction
            }, null, "ActionFunction");
        }

        //build the parts for training
        TrainerParamsMimic trainingParams = trainerParams as TrainerParamsMimic;

        if (trainerParams != null && trainingParams == null)
        {
            Debug.LogError("Trainer params for Supervised learning mode needs to be a TrainerParamsMimic type");
        }
        if (trainingParams != null)
        {
            //training inputs
            var inputActionLabel = UnityTFUtils.Input(new int?[] { ActionSpace == SpaceType.continuous ? ActionSize : 1 }, name: "InputAction", dtype: ActionSpace == SpaceType.continuous ? DataType.Float : DataType.Int32)[0];
            //create the loss
            Tensor loss = null;
            if (ActionSpace == SpaceType.discrete)
            {
                Tensor actionOnehot   = K.one_hot(inputActionLabel, K.constant(ActionSize, dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f));
                Tensor reshapedOnehot = K.reshape(actionOnehot, new int[] { -1, ActionSize });
                loss = K.mean(K.categorical_crossentropy(reshapedOnehot, outputAction, false));
            }
            else
            {
                if (hasVariance)
                {
                    loss = K.mean(K.mean(0.5 * K.square(inputActionLabel - outputAction) / outputVar + 0.5 * K.log(outputVar)));
                }
                else
                {
                    loss = K.mean(new MeanSquareError().Call(inputActionLabel, outputAction));
                }
            }
            //add inputs, outputs and parameters to the list
            List <Tensor> updateParameters = network.GetWeights();
            List <Tensor> allInputs        = new List <Tensor>();


            if (HasVectorObservation)
            {
                allInputs.Add(inputStateTensor);
            }
            if (HasVisualObservation)
            {
                allInputs.AddRange(inputVisualTensors);
            }
            allInputs.Add(inputActionLabel);

            //create optimizer and create necessary functions
            var updates = AddOptimizer(updateParameters, loss, optimizer);
            UpdateFunction = K.function(allInputs, new List <Tensor> {
                loss
            }, updates, "UpdateFunction");
        }
    }
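
Note: in the discrete branch above, the integer action label is one-hot encoded and scored with categorical cross-entropy against the network's action probabilities. A plain-C# sketch of that computation (the 1e-8 epsilon is an assumption for numerical safety; the library call does not show one):

    using System;

    public static class CrossEntropySketch
    {
        public static double[] OneHot(int label, int numClasses)
        {
            var v = new double[numClasses];
            v[label] = 1.0;
            return v;
        }

        public static double CategoricalCrossEntropy(double[] onehotLabel, double[] probabilities)
        {
            double loss = 0.0;
            for (int i = 0; i < onehotLabel.Length; ++i)
            {
                loss -= onehotLabel[i] * Math.Log(probabilities[i] + 1e-8);
            }
            return loss;
        }
    }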
Example #8
    protected void InitializeSLStructureContinuousAction(Tensor vectorObs, Tensor normalizedVectorObs, List <Tensor> visualObs, TrainerParams trainerParams)
    {
        //build the network
        Tensor outputValue = null; Tensor outputActionMean = null; Tensor outputLogVariance = null;

        network.BuildNetworkForContinuousActionSapce(normalizedVectorObs, visualObs, null, null, ActionSizes[0], out outputActionMean, out outputValue, out outputLogVariance);
        Tensor outputAction = outputActionMean;
        Tensor outputVar    = K.exp(outputLogVariance);

        SLHasVar = outputLogVariance != null;

        List <Tensor> observationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            observationInputs.Add(vectorObs);
        }
        if (HasVisualObservation)
        {
            observationInputs.AddRange(visualObs);
        }
        if (SLHasVar)
        {
            ActionFunction = K.function(observationInputs, new List <Tensor> {
                outputAction, outputVar
            }, null, "ActionFunction");
        }
        else
        {
            ActionFunction = K.function(observationInputs, new List <Tensor> {
                outputAction
            }, null, "ActionFunction");
        }

        //build the parts for training
        TrainerParamsMimic trainingParams = trainerParams as TrainerParamsMimic;

        if (trainerParams != null && trainingParams == null)
        {
            Debug.LogError("Trainer params for Supervised learning mode needs to be a TrainerParamsMimic type");
        }
        if (trainingParams != null)
        {
            //training inputs
            var inputActionLabel = UnityTFUtils.Input(new int?[] { ActionSizes[0] }, name: "InputAction", dtype: DataType.Float)[0];
            //create the loss
            Tensor loss = null;
            if (SLHasVar)
            {
                loss = K.mean(K.mean(0.5 * K.square(inputActionLabel - outputAction) / outputVar + 0.5 * outputLogVariance));
            }
            else
            {
                loss = K.mean(new MeanSquareError().Call(inputActionLabel, outputAction));
            }

            //add inputs, outputs and parameters to the list
            List <Tensor> updateParameters = network.GetActorWeights();
            List <Tensor> allInputs        = new List <Tensor>();
            allInputs.AddRange(observationInputs);
            allInputs.Add(inputActionLabel);

            //create optimizer and create necessary functions
            var updates = AddOptimizer(updateParameters, loss, optimizer);
            UpdateSLFunction = K.function(allInputs, new List <Tensor> {
                loss
            }, updates, "UpdateFunction");
        }
    }
Example #9
    protected void InitializeSLStructureDiscreteAction(Tensor vectorObs, Tensor normalizedVectorObs, List <Tensor> visualObs, TrainerParams trainerParams)
    {
        //all inputs list
        List <Tensor> allObservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allObservationInputs.Add(vectorObs);
        }
        if (HasVisualObservation)
        {
            allObservationInputs.AddRange(visualObs);
        }

        //build basic network
        Tensor[] outputActionsLogits = null;
        Tensor   outputValue         = null;

        network.BuildNetworkForDiscreteActionSpace(normalizedVectorObs, visualObs, null, null, ActionSizes, out outputActionsLogits, out outputValue);

        //the action masks input placeholders
        List <Tensor> actionMasksInputs = new List <Tensor>();

        for (int i = 0; i < ActionSizes.Length; ++i)
        {
            actionMasksInputs.Add(UnityTFUtils.Input(new int?[] { ActionSizes[i] }, name: "ActionMask" + i)[0]);
        }
        //apply the action masks, normalize, and get the final action tensors
        Tensor[] outputActions, outputNormalizedLogits;
        CreateDiscreteActionMaskingLayer(outputActionsLogits, actionMasksInputs.ToArray(), out outputActions, out outputNormalizedLogits);

        //output tensors for discrete actions; includes all selected actions
        var outputDiscreteActions = new List <Tensor>();

        outputDiscreteActions.Add(K.identity(K.cast(ActionSizes.Length == 1 ? outputActions[0] : K.concat(outputActions.ToList(), 1), DataType.Float), "OutputAction"));
        var actionFunctionInputs = new List <Tensor>();

        actionFunctionInputs.AddRange(allObservationInputs);
        actionFunctionInputs.AddRange(actionMasksInputs);
        ActionFunction = K.function(actionFunctionInputs, outputDiscreteActions, null, "ActionFunction");


        //build the parts for training
        TrainerParamsMimic trainingParams = trainerParams as TrainerParamsMimic;

        if (trainerParams != null && trainingParams == null)
        {
            Debug.LogError("Trainer params for Supervised learning mode needs to be a TrainerParamsMimic type");
        }
        if (trainingParams != null)
        {
            //training inputs
            var inputActionLabels = UnityTFUtils.Input(new int?[] { ActionSizes.Length }, name: "InputAction", dtype: DataType.Int32)[0];
            //split the input for each discrete branch
            List <Tensor> inputActionsDiscreteSeparated = null, onehotInputActions = null;    //for discrete action space
            var           splits = new int[ActionSizes.Length];
            for (int i = 0; i < splits.Length; ++i)
            {
                splits[i] = 1;
            }
            inputActionsDiscreteSeparated = K.split(inputActionLabels, K.constant(splits, dtype: DataType.Int32), K.constant(1, dtype: DataType.Int32), ActionSizes.Length);

            //create the loss
            onehotInputActions = inputActionsDiscreteSeparated.Select((x, i) => K.reshape(K.one_hot(x, K.constant <int>(ActionSizes[i], dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f)), new int[] { -1, ActionSizes[i] })).ToList();

            var    losses = onehotInputActions.Select((x, i) => K.mean(K.categorical_crossentropy(x, outputNormalizedLogits[i], true))).ToList();
            Tensor loss = losses.Aggregate((x, s) => x + s);

            //add inputs, outputs and parameters to the list
            List <Tensor> updateParameters = network.GetActorWeights();
            List <Tensor> allInputs = new List <Tensor>();
            allInputs.AddRange(actionFunctionInputs);
            allInputs.Add(inputActionLabels);

            //create optimizer and create necessary functions
            var updates = AddOptimizer(updateParameters, loss, optimizer);
            UpdateSLFunction = K.function(allInputs, new List <Tensor> {
                loss
            }, updates, "UpdateFunction");
        }
    }
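
Note: the split/one-hot/aggregate pipeline above amounts to one cross-entropy term per action branch, summed into a single loss. A plain-C# sketch of the per-sample computation over per-branch probability vectors (epsilon assumed, as before):

    using System;

    public static class BranchedLossSketch
    {
        //actionLabels[b] is the chosen index in branch b; branchProbs[b] is that branch's probability vector
        public static double BranchedCrossEntropy(int[] actionLabels, double[][] branchProbs)
        {
            double total = 0.0;
            for (int b = 0; b < actionLabels.Length; ++b)
            {
                total -= Math.Log(branchProbs[b][actionLabels[b]] + 1e-8);
            }
            return total;
        }
    }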
Example #10
    protected void CreatePPOOptimizer(TrainerParamsPPO trainingParams, Tensor entropy, Tensor actionLogProb, Tensor outputValueFromNetwork, List <Tensor> extraInputTensors, List <Tensor> weightsToUpdate)
    {
        ClipEpsilon       = trainingParams.clipEpsilon;
        ValueLossWeight   = trainingParams.valueLossWeight;
        EntropyLossWeight = trainingParams.entropyLossWeight;
        ClipValueLoss     = trainingParams.clipValueLoss;


        var inputOldLogProb  = UnityTFUtils.Input(new int?[] { ActionSpace == SpaceType.continuous ? ActionSizes[0] : ActionSizes.Length }, name: "InputOldLogProb")[0];
        var inputAdvantage   = UnityTFUtils.Input(new int?[] { 1 }, name: "InputAdvantage")[0];
        var inputTargetValue = UnityTFUtils.Input(new int?[] { 1 }, name: "InputTargetValue")[0];
        var inputOldValue    = UnityTFUtils.Input(new int?[] { 1 }, name: "InputOldValue")[0];

        var inputClipEpsilon       = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ClipEpsilon", dtype: DataType.Float)[0];
        var inputClipValueLoss     = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ClipValueLoss", dtype: DataType.Float)[0];
        var inputValuelossWeight   = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ValueLossWeight", dtype: DataType.Float)[0];
        var inputEntropyLossWeight = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "EntropyLossWeight", dtype: DataType.Float)[0];


        // value loss
        Tensor outputValueLoss = null;

        using (K.name_scope("ValueLoss"))
        {
            var clippedValueEstimate = inputOldValue + K.clip(outputValueFromNetwork - inputOldValue, 0.0f - inputClipValueLoss, inputClipValueLoss);
            var valueLoss1           = new MeanSquareError().Call(outputValueFromNetwork, inputTargetValue);
            var valueLoss2           = new MeanSquareError().Call(clippedValueEstimate, inputTargetValue);
            outputValueLoss = K.mean(K.maximum(valueLoss1, valueLoss2));
        }
        //var outputValueLoss = K.mean(valueLoss1);

        // Clipped Surrogate loss
        Tensor outputPolicyLoss;

        using (K.name_scope("ClippedCurreogateLoss"))
        {
            //Debug.LogWarning("testnew");
            //var probStopGradient = K.stop_gradient(actionProb);
            var probRatio = K.exp(actionLogProb - inputOldLogProb);
            var p_opt_a   = probRatio * inputAdvantage;
            var p_opt_b   = K.clip(probRatio, 1.0f - inputClipEpsilon, 1.0f + inputClipEpsilon) * inputAdvantage;

            outputPolicyLoss = (-1f) * K.mean(K.mean(K.minimun(p_opt_a, p_opt_b)), name: "ClippedSurrogateLoss");
        }
        //final weighted loss
        var outputLoss = outputPolicyLoss + inputValuelossWeight * outputValueLoss;

        outputLoss = outputLoss - inputEntropyLossWeight * entropy;
        outputLoss = K.identity(outputLoss, "OutputLoss");

        //add inputs, outputs and parameters to the list
        List <Tensor> allInputs = new List <Tensor>();

        allInputs.Add(inputOldLogProb);
        allInputs.Add(inputTargetValue);
        allInputs.Add(inputOldValue);
        allInputs.Add(inputAdvantage);
        allInputs.Add(inputClipEpsilon);
        allInputs.Add(inputClipValueLoss);
        allInputs.Add(inputValuelossWeight);
        allInputs.Add(inputEntropyLossWeight);

        allInputs.AddRange(extraInputTensors);

        //create optimizer and create necessary functions
        var updates = AddOptimizer(weightsToUpdate, outputLoss, optimizer);

        UpdatePPOFunction = K.function(allInputs, new List <Tensor> {
            outputLoss, outputValueLoss, outputPolicyLoss, entropy
        }, updates, "UpdateFunction");
    }
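
Note: the core of CreatePPOOptimizer is the clipped surrogate objective. With probability ratio r = exp(logProb - oldLogProb), the policy loss is -mean(min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A)). A plain-C# sketch over per-sample arrays:

    using System;

    public static class PpoLossSketch
    {
        public static double ClippedSurrogateLoss(double[] logProb, double[] oldLogProb,
                                                  double[] advantage, double clipEpsilon)
        {
            double sum = 0.0;
            for (int i = 0; i < logProb.Length; ++i)
            {
                double ratio   = Math.Exp(logProb[i] - oldLogProb[i]);
                double clipped = Math.Max(1.0 - clipEpsilon, Math.Min(ratio, 1.0 + clipEpsilon));
                sum += Math.Min(ratio * advantage[i], clipped * advantage[i]);
            }
            return -sum / logProb.Length;    //negated because the optimizer minimizes
        }
    }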
Example #11
    protected void InitializePPOStructureDiscreteAction(Tensor vectorObs, Tensor normalizedVectorObs, List <Tensor> visualObs, TrainerParams trainerParams)
    {
        //all inputs list
        List <Tensor> allObservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allObservationInputs.Add(vectorObs);
        }
        if (HasVisualObservation)
        {
            allObservationInputs.AddRange(visualObs);
        }

        Tensor[] outputActionsLogits = null; Tensor outputValue = null;
        network.BuildNetworkForDiscreteActionSpace(normalizedVectorObs, visualObs, null, null, ActionSizes, out outputActionsLogits, out outputValue);

        ValueFunction = K.function(allObservationInputs, new List <Tensor> {
            outputValue
        }, null, "ValueFunction");

        //the action masks input placeholders
        List <Tensor> actionMasksInputs = new List <Tensor>();

        for (int i = 0; i < ActionSizes.Length; ++i)
        {
            actionMasksInputs.Add(UnityTFUtils.Input(new int?[] { ActionSizes[i] }, name: "ActionMask" + i)[0]);
        }

        Tensor[] outputActions, outputNormalizedLogits;
        CreateDiscreteActionMaskingLayer(outputActionsLogits, actionMasksInputs.ToArray(), out outputActions, out outputNormalizedLogits);

        //output tensors for discrete actions; includes all selected actions and the normalized logits of all actions
        var outputDiscreteActions = new List <Tensor>();

        outputDiscreteActions.Add(K.identity(K.cast(ActionSizes.Length == 1? outputActions[0]: K.concat(outputActions.ToList(), 1), DataType.Float), "OutputAction"));
        outputDiscreteActions.AddRange(outputNormalizedLogits);
        var actionFunctionInputs = new List <Tensor>();

        actionFunctionInputs.AddRange(allObservationInputs); actionFunctionInputs.AddRange(actionMasksInputs);
        ActionFunction = K.function(actionFunctionInputs, outputDiscreteActions, null, "ActionFunction");


        TrainerParamsPPO trainingParams = trainerParams as TrainerParamsPPO;

        if (trainingParams != null)
        {
            // action probability from input action
            Tensor        outputEntropy;
            List <Tensor> inputActionsDiscreteSeparated = null, onehotInputActions = null;    //for discrete action space

            Tensor inputAction = UnityTFUtils.Input(new int?[] { ActionSizes.Length }, name: "InputActions", dtype: DataType.Int32)[0];

            //split the input for each discrete branch
            var splits = new int[ActionSizes.Length];
            for (int i = 0; i < splits.Length; ++i)
            {
                splits[i] = 1;
            }
            inputActionsDiscreteSeparated = K.split(inputAction, K.constant(splits, dtype: DataType.Int32), K.constant(1, dtype: DataType.Int32), ActionSizes.Length);

            Tensor actionLogProb = null;
            using (K.name_scope("ActionProbAndEntropy"))
            {
                onehotInputActions = inputActionsDiscreteSeparated.Select((x, i) => K.reshape(K.one_hot(x, K.constant <int>(ActionSizes[i], dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f)), new int[] { -1, ActionSizes[i] })).ToList();

                //entropy
                var entropies = outputActionsLogits.Select((t) => { return(K.mean((-1.0f) * K.sum(K.softmax(t) * K.log(K.softmax(t) + 0.00000001f), axis: 1), 0)); });
                outputEntropy = entropies.Aggregate((x, y) => { return(x + y); });

                //probabilities
                var actionProbsArray = ActionSizes.Select((x, i) => { return(K.sum(outputNormalizedLogits[i] * onehotInputActions[i], 1, true)); }).ToList();
                //actionLogProb = K.reshape(K.sum(K.log(outputActionFromNetwork) * onehotInputAction, 1), new int[] { -1, 1 });
                actionLogProb = ActionSizes.Length == 1 ? actionProbsArray[0]:K.concat(actionProbsArray, 1);
            }

            List <Tensor> extraInputs = new List <Tensor>();
            extraInputs.AddRange(actionFunctionInputs);
            extraInputs.Add(inputAction);

            CreatePPOOptimizer(trainingParams, outputEntropy, actionLogProb, outputValue, extraInputs, network.GetWeights());
        }
    }
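
Note: the entropy term above is computed per branch as -sum(softmax(logits) * log(softmax(logits) + epsilon)) and summed across branches. A plain-C# sketch for a single branch (the max-subtraction for numerical stability is an addition; the tensor code does not show it):

    using System;
    using System.Linq;

    public static class EntropySketch
    {
        public static double SoftmaxEntropy(double[] logits)
        {
            double max     = logits.Max();
            double sumExp  = logits.Sum(l => Math.Exp(l - max));
            double entropy = 0.0;
            foreach (double l in logits)
            {
                double p = Math.Exp(l - max) / sumExp;
                entropy -= p * Math.Log(p + 1e-8);
            }
            return entropy;
        }
    }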
Example #12
    /// <summary>
    /// Initialize the GAN model based on the current value fields, independent of the ML-Agents integration.
    /// </summary>
    public void Initialize(bool enableTraining = true)
    {
        Debug.Assert(Initialized == false, "model already initialized");

        HasNoiseInput      = inputNoiseShape != null && inputNoiseShape.Length > 0;
        HasConditionInput  = inputConditionShape != null && inputConditionShape.Length > 0;
        HasGeneratorL2Loss = hasGeneratorL2Loss;


        //create generator input tensors
        Tensor inputCondition = null;

        if (HasConditionInput)
        {
            inputCondition = UnityTFUtils.Input(inputConditionShape.Select((t) => (int?)t).ToArray(), name: "InputCondition")[0];
        }
        Tensor inputNoise = null;

        if (HasNoiseInput)
        {
            inputNoise = UnityTFUtils.Input(inputNoiseShape.Select((t) => (int?)t).ToArray(), name: "InputNoise")[0];
        }

        Debug.Assert(HasNoiseInput || HasConditionInput, "GAN needs at least one of noise or condition input");

        Tensor inputTargetToJudge = UnityTFUtils.Input(outputShape.Select((t) => (int?)t).ToArray(), name: "InputTargetToJudge")[0];

        //build the network
        Tensor generatorOutput, disOutForGenerator, discOutTarget;

        network.BuildNetwork(inputCondition, inputNoise, inputTargetToJudge, outputShape, out generatorOutput, out discOutTarget, out disOutForGenerator);

        //build the loss
        //generator gan loss
        Tensor genGANLoss = K.constant(0.0f, new int[] { }, DataType.Float) - K.mean(K.binary_crossentropy(disOutForGenerator, K.constant(0.0f, new int[] { }, DataType.Float), false), new int[] { 0, 1 });
        Tensor genLoss    = genGANLoss;
        //generator l2Loss if use it
        Tensor l2Loss = null;
        Tensor inputGeneratorTarget = null;
        Tensor inputL2LossWeight    = null;

        if (hasGeneratorL2Loss)
        {
            inputL2LossWeight    = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "l2LossWeight", dtype: DataType.Float)[0];
            inputGeneratorTarget = UnityTFUtils.Input(outputShape.Select((t) => (int?)t).ToArray(), name: "GeneratorTarget")[0];

            int[] reduceDim = new int[outputShape.Length];
            for (int i = 0; i < reduceDim.Length; ++i)
            {
                reduceDim[i] = i;
            }
            l2Loss  = K.mul(inputL2LossWeight, K.mean(new MeanSquareError().Call(inputGeneratorTarget, generatorOutput), reduceDim));
            genLoss = genGANLoss + l2Loss;
        }

        //discriminator loss
        inputCorrectLabel = UnityTFUtils.Input(new int?[] { 1 }, name: "InputCorrectLabel")[0];
        Tensor discLoss = K.mean(K.binary_crossentropy(discOutTarget, inputCorrectLabel, false), new int[] { 0, 1 });



        //create the Functions inputs
        List <Tensor> generatorTrainInputs     = new List <Tensor>();
        List <Tensor> generateInputs           = new List <Tensor>();
        List <Tensor> discriminatorTrainInputs = new List <Tensor>();

        discriminatorTrainInputs.Add(inputTargetToJudge);
        discriminatorTrainInputs.Add(inputCorrectLabel);
        if (HasConditionInput)
        {
            generatorTrainInputs.Add(inputCondition);
            generateInputs.Add(inputCondition);
            discriminatorTrainInputs.Add(inputCondition);
        }
        if (HasNoiseInput)
        {
            generatorTrainInputs.Add(inputNoise);
            generateInputs.Add(inputNoise);
        }
        if (hasGeneratorL2Loss)
        {
            generatorTrainInputs.Add(inputGeneratorTarget);
            generatorTrainInputs.Add(inputL2LossWeight);
        }

        //create optimizers
        if (enableTraining)
        {
            var generatorUpdate = AddOptimizer(network.GetGeneratorWeights(), genLoss, generatorOptimizer);
            trainGeneratorFunction = K.function(generatorTrainInputs, new List <Tensor> {
                genLoss
            }, generatorUpdate, "GeneratorUpdateFunction");

            var discriminatorUpdate = AddOptimizer(network.GetDiscriminatorWeights(), discLoss, discriminatorOptimizer);
            trainDiscriminatorFunction = K.function(discriminatorTrainInputs, new List <Tensor> {
                discLoss
            }, discriminatorUpdate, "DiscriminatorUpdateFunction");
        }
        generateFunction = K.function(generateInputs, new List <Tensor> {
            generatorOutput
        }, null, "GenerateFunction");

        //create function for training with prediction method
        CreateTrainWithPredictionFunctions();

        Initialized     = true;
        TrainingEnabled = enableTraining;
    }
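
Note: both GAN losses above are built from binary cross-entropy. The discriminator is scored against inputCorrectLabel, and the generator loss is the negated cross-entropy of disOutForGenerator against the constant 0, which drives the discriminator's output on generated samples toward 1. A plain-C# sketch of the underlying per-sample term (the argument order of K.binary_crossentropy is assumed to be (output, target), matching the calls above; epsilon added for safety):

    using System;

    public static class BceSketch
    {
        public static double BinaryCrossEntropy(double predicted, double target)
        {
            return -(target * Math.Log(predicted + 1e-8)
                     + (1.0 - target) * Math.Log(1.0 - predicted + 1e-8));
        }
    }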
    /// <summary>
    /// Initialize the model for PPO
    /// </summary>
    /// <param name="trainerParams"></param>
    /// <param name="stateTensor"></param>
    /// <param name="inputVisualTensors"></param>
    /// <param name="outputValueFromNetwork"></param>
    /// <param name="outputActionFromNetwork"></param>
    /// <param name="outputVarianceFromNetwork"></param>
    /// <param name="weightsToUpdate"></param>
    protected void InitializePPOStructures(TrainerParams trainerParams, Tensor stateTensor, List <Tensor> inputVisualTensors, Tensor outputValueFromNetwork, Tensor outputActionFromNetwork, Tensor outputVarianceFromNetwork, List <Tensor> weightsToUpdate)
    {
        List <Tensor> allobservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allobservationInputs.Add(stateTensor);
        }
        if (HasVisualObservation)
        {
            allobservationInputs.AddRange(inputVisualTensors);
        }

        ValueFunction = K.function(allobservationInputs, new List <Tensor> {
            outputValueFromNetwork
        }, null, "ValueFunction");

        Tensor outputActualAction = null; Tensor actionProb = null;

        if (ActionSpace == SpaceType.continuous)
        {
            using (K.name_scope("SampleAction"))
            {
                outputActualAction = K.standard_normal(K.shape(outputActionFromNetwork), DataType.Float) * K.sqrt(outputVarianceFromNetwork) + outputActionFromNetwork;
            }
            using (K.name_scope("ActionProbs"))
            {
                actionProb = K.normal_probability(K.stop_gradient(outputActualAction), outputActionFromNetwork, outputVarianceFromNetwork);
            }
            ActionFunction = K.function(allobservationInputs, new List <Tensor> {
                outputActualAction, actionProb, outputActionFromNetwork, outputVarianceFromNetwork
            }, null, "ActionFunction");

            var probInputs = new List <Tensor>(); probInputs.AddRange(allobservationInputs); probInputs.Add(outputActualAction);
            ActionProbabilityFunction = K.function(probInputs, new List <Tensor> {
                actionProb
            }, null, "ActionProbabilityFunction");
        }
        else
        {
            ActionFunction = K.function(allobservationInputs, new List <Tensor> {
                outputActionFromNetwork
            }, null, "ActionFunction");
        }

        TrainerParamsPPO trainingParams = trainerParams as TrainerParamsPPO;

        if (trainingParams != null)
        {
            //training needed inputs

            var inputOldProb     = UnityTFUtils.Input(new int?[] { ActionSpace == SpaceType.continuous ? ActionSize : 1 }, name: "InputOldProb")[0];
            var inputAdvantage   = UnityTFUtils.Input(new int?[] { 1 }, name: "InputAdvantage")[0];
            var inputTargetValue = UnityTFUtils.Input(new int?[] { 1 }, name: "InputTargetValue")[0];
            var inputOldValue    = UnityTFUtils.Input(new int?[] { 1 }, name: "InputOldValue")[0];

            ClipEpsilon       = trainingParams.clipEpsilon;
            ValueLossWeight   = trainingParams.valueLossWeight;
            EntropyLossWeight = trainingParams.entropyLossWeight;

            var inputClipEpsilon       = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ClipEpsilon", dtype: DataType.Float)[0];
            var inputValuelossWeight   = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ValueLossWeight", dtype: DataType.Float)[0];
            var inputEntropyLossWeight = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "EntropyLossWeight", dtype: DataType.Float)[0];

            // action probability from input action
            Tensor outputEntropy;
            Tensor inputActionDiscrete = null, onehotInputAction = null;    //for discrete action space

            if (ActionSpace == SpaceType.continuous)
            {
                using (K.name_scope("Entropy"))
                {
                    var temp = K.mul(outputVarianceFromNetwork, 2 * Mathf.PI * 2.7182818285);
                    temp = K.mul(K.log(temp), 0.5);
                    if (outputVarianceFromNetwork.shape.Length == 2)
                    {
                        outputEntropy = K.mean(K.mean(temp, 0, false), name: "OutputEntropy");
                    }
                    else
                    {
                        outputEntropy = K.mean(temp, 0, false, name: "OutputEntropy");
                    }
                }
            }
            else
            {
                using (K.name_scope("ActionProbAndEntropy"))
                {
                    inputActionDiscrete = UnityTFUtils.Input(new int?[] { 1 }, name: "InputAction", dtype: DataType.Int32)[0];
                    onehotInputAction   = K.one_hot(inputActionDiscrete, K.constant <int>(ActionSize, dtype: DataType.Int32), K.constant(1.0f), K.constant(0.0f));
                    onehotInputAction   = K.reshape(onehotInputAction, new int[] { -1, ActionSize });
                    outputEntropy       = K.mean((-1.0f) * K.sum(outputActionFromNetwork * K.log(outputActionFromNetwork + 0.00000001f), axis: 1), 0);
                    actionProb          = K.reshape(K.sum(outputActionFromNetwork * onehotInputAction, 1), new int[] { -1, 1 });
                }
            }

            // value loss
            Tensor outputValueLoss = null;
            using (K.name_scope("ValueLoss"))
            {
                var clippedValueEstimate = inputOldValue + K.clip(outputValueFromNetwork - inputOldValue, 0.0f - inputClipEpsilon, inputClipEpsilon);
                var valueLoss1           = new MeanSquareError().Call(outputValueFromNetwork, inputTargetValue);
                var valueLoss2           = new MeanSquareError().Call(clippedValueEstimate, inputTargetValue);
                outputValueLoss = K.mean(K.maximum(valueLoss1, valueLoss2));
            }
            //var outputValueLoss = K.mean(valueLoss1);

            // Clipped Surrogate loss
            Tensor outputPolicyLoss;
            using (K.name_scope("ClippedCurreogateLoss"))
            {
                //Debug.LogWarning("testnew");
                //var probStopGradient = K.stop_gradient(actionProb);
                var probRatio = actionProb / (inputOldProb + 0.0000000001f);
                var p_opt_a   = probRatio * inputAdvantage;
                var p_opt_b   = K.clip(probRatio, 1.0f - inputClipEpsilon, 1.0f + inputClipEpsilon) * inputAdvantage;

                outputPolicyLoss = (-1f) * K.mean(K.mean(K.minimun(p_opt_a, p_opt_b)), name: "ClippedSurrogateLoss");
            }
            //final weighted loss
            var outputLoss = outputPolicyLoss + inputValuelossWeight * outputValueLoss;
            outputLoss = outputLoss - inputEntropyLossWeight * outputEntropy;
            outputLoss = K.identity(outputLoss, "OutputLoss");

            //add inputs, outputs and parameters to the list
            List <Tensor> allInputs = new List <Tensor>();
            if (HasVectorObservation)
            {
                allInputs.Add(stateTensor);
            }
            if (HasVisualObservation)
            {
                allInputs.AddRange(inputVisualTensors);
            }
            if (ActionSpace == SpaceType.continuous)
            {
                allInputs.Add(outputActualAction);
            }
            else
            {
                allInputs.Add(inputActionDiscrete);
            }

            allInputs.Add(inputOldProb);
            allInputs.Add(inputTargetValue);
            allInputs.Add(inputOldValue);
            allInputs.Add(inputAdvantage);
            allInputs.Add(inputClipEpsilon);
            allInputs.Add(inputValuelossWeight);
            allInputs.Add(inputEntropyLossWeight);

            //create optimizer and create necessary functions
            var updates = AddOptimizer(weightsToUpdate, outputLoss, optimizer);
            UpdatePPOFunction = K.function(allInputs, new List <Tensor> {
                outputLoss, outputValueLoss, outputPolicyLoss, outputEntropy, actionProb
            }, updates, "UpdateFunction");
        }
    }
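
Note: the continuous-action entropy above is the closed-form Gaussian entropy 0.5 * log(2 * pi * e * variance); the literal 2.7182818285 in the code is Euler's number e. A plain-C# sketch for a single action dimension:

    using System;

    public static class GaussianEntropySketch
    {
        public static double GaussianEntropy(double variance)
        {
            return 0.5 * Math.Log(2.0 * Math.PI * Math.E * variance);
        }
    }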
Example #14
        /// <summary>
        /// Adds a layer instance on top of the layer stack.
        /// </summary>
        ///
        /// <param name="layer">The layer.</param>
        ///
        public void Add(Layer layer)
        {
            if (outputs.Count == 0)
            {
                // first layer in model: check that it is an input layer
                if (layer.inbound_nodes.Count == 0)
                {
                    // create an input layer
                    if (layer.batch_input_shape == null)
                    {
                        throw new Exception("The first layer in a Sequential model must get an 'input_shape' or 'batch_input_shape' argument.");
                    }

                    // Instantiate the input layer.
                    var x = UnityTFUtils.Input(batch_shape: layer.batch_input_shape, dtype: layer.dtype, name: $"{layer.name}_input");

                    //Debug.Assert(x[0]._keras_history.Value.layer.GetType() == typeof(InputLayer));
                    Debug.Assert(x[0]._keras_history.Value.Item1.GetType() == typeof(InputLayer));

                    // This will build the current layer and create the node connecting
                    // the current layer to the input layer we just created.
                    layer.Call(x);

                    //Debug.Assert(x[0]._keras_history.Value.layer.GetType() == typeof(InputLayer));
                    Debug.Assert(x[0]._keras_history.Value.Item1.GetType() == typeof(InputLayer));
                }


                if (layer.inbound_nodes.Count != 1)
                {
                    throw new Exception($"A layer added to a Sequential model must not already be connected somewhere else. Model received layer '{layer.name}' which has {layer.inbound_nodes.Count} pre-existing inbound connections.");
                }

                if (layer.inbound_nodes[0].output_tensors.Count != 1)
                {
                    throw new Exception("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
                }

                this.outputs = new List <Tensor> {
                    layer.inbound_nodes[0].output_tensors[0]
                };
                this.inputs = base.get_source_inputs(this.outputs[0]);

                // We create an input node, which we will keep updated
                // as we add more layers
                var node = new Node(outbound_layer: this,
                                    inbound_layers: new List <Layer>(),
                                    node_indices: new List <int?>(),
                                    tensor_indices: new List <int?>(),
                                    input_tensors: this.inputs,
                                    output_tensors: this.outputs,
                                    // no model-level masking for now
                                    input_masks: this.inputs.Select(x => (Tensor)null).ToList(),
                                    output_masks: new List <Tensor>()
                {
                    null
                },
                                    input_shapes: this.inputs.Select(x => x._keras_shape).ToList(),
                                    output_shapes: this.outputs.Select(x => x._keras_shape).ToList()
                                    );
            }
            else
            {
                List <Tensor> output_tensor = layer.Call(this.outputs);
                if (output_tensor.Count > 1)
                {
                    throw new Exception("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");
                }

                this.outputs = output_tensor;

                // update this.inbound_nodes
                this.inbound_nodes[0].output_tensors = this.outputs;
                this.inbound_nodes[0].output_shapes  = new List <int?[]> {
                    this.outputs[0]._keras_shape
                };
            }

            this.layers.Add(layer);
            this.built = false;
        }
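
Note: a hypothetical usage of this Add-based API, assembled from the exception messages above and the layer constructors in Example #6 (the Sequential constructor and how the first layer carries its input shape are assumptions, not confirmed by the source):

        //var model = new Sequential();
        //model.Add(firstDenseWithBatchInputShape);      //the first layer must carry batch_input_shape
        //model.Add(new Dense(10, new ReLU(), true));    //later layers are chained via Call
        //model.Add(new Dense(1, new ReLU(), true));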
Example #15
    /// <summary>
    /// Initialize the model for PPO-CMA
    /// </summary>
    /// <param name="trainerParams">trainer parameters</param>
    /// <param name="stateTensor">tensor of the vector observation</param>
    /// <param name="inputVisualTensors">list of visual observation input tensors</param>
    /// <param name="outputValueFromNetwork">value output tensor of the network</param>
    /// <param name="outputActionMeanFromNetwork">action mean output tensor of the network</param>
    /// <param name="outActionLogVarianceFromNetwork">action log-variance output tensor of the network</param>
    /// <param name="valueWeights">weights updated by the value optimizer</param>
    /// <param name="meanWeights">weights updated by the mean optimizer</param>
    /// <param name="varweights">weights updated by the variance optimizer</param>
    protected void InitializePPOCMAStructures(TrainerParams trainerParams, Tensor stateTensor, List <Tensor> inputVisualTensors, Tensor outputValueFromNetwork, Tensor outputActionMeanFromNetwork, Tensor outActionLogVarianceFromNetwork, List <Tensor> valueWeights, List <Tensor> meanWeights, List <Tensor> varweights)
    {
        List <Tensor> allobservationInputs = new List <Tensor>();

        if (HasVectorObservation)
        {
            allobservationInputs.Add(stateTensor);
        }
        if (HasVisualObservation)
        {
            allobservationInputs.AddRange(inputVisualTensors);
        }

        ValueFunction = K.function(allobservationInputs, new List <Tensor> {
            outputValueFromNetwork
        }, null, "ValueFunction");

        Tensor outputActualAction = null;
        Tensor outputVariance     = K.exp(outActionLogVarianceFromNetwork);

        using (K.name_scope("SampleAction"))
        {
            outputActualAction = K.standard_normal(K.shape(outputActionMeanFromNetwork), DataType.Float) * K.sqrt(outputVariance) + outputActionMeanFromNetwork;
        }

        ActionFunction = K.function(allobservationInputs, new List <Tensor> {
            outputActualAction, outputActionMeanFromNetwork, outputVariance
        }, null, "ActionFunction");

        TrainerParamsPPO trainingParams = trainerParams as TrainerParamsPPO;

        if (trainingParams != null)
        {
            //training needed inputs
            var inputOldAction   = UnityTFUtils.Input(new int?[] { ActionSizes[0] }, name: "InputOldAction")[0];
            var inputAdvantage   = UnityTFUtils.Input(new int?[] { 1 }, name: "InputAdvantage")[0];
            var inputTargetValue = UnityTFUtils.Input(new int?[] { 1 }, name: "InputTargetValue")[0];
            var inputOldValue    = UnityTFUtils.Input(new int?[] { 1 }, name: "InputOldValue")[0];

            //var inputClipEpsilon = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ClipEpsilon", dtype: DataType.Float)[0];

            var inputClipEpsilonValue = UnityTFUtils.Input(batch_shape: new int?[] { }, name: "ClipEpsilonValue", dtype: DataType.Float)[0];
            // value loss
            Tensor outputValueLoss = null;
            using (K.name_scope("ValueLoss"))
            {
                var clippedValueEstimate = inputOldValue + K.clip(outputValueFromNetwork - inputOldValue, 0.0f - inputClipEpsilonValue, inputClipEpsilonValue);
                var valueLoss1           = new MeanSquareError().Call(outputValueFromNetwork, inputTargetValue);
                var valueLoss2           = new MeanSquareError().Call(clippedValueEstimate, inputTargetValue);
                outputValueLoss = K.mean(K.maximum(valueLoss1, valueLoss2));
                outputValueLoss = K.mean(valueLoss1);    //note: this overwrites the clipped value loss above with the plain unclipped loss
            }

            var           valueUpdates = AddOptimizer(valueWeights, outputValueLoss, optimizerValue);
            List <Tensor> valueInputs  = new List <Tensor>();
            if (HasVectorObservation)
            {
                valueInputs.Add(stateTensor);
            }
            if (HasVisualObservation)
            {
                valueInputs.AddRange(inputVisualTensors);
            }
            valueInputs.Add(inputOldValue);
            valueInputs.Add(inputTargetValue);
            valueInputs.Add(inputClipEpsilonValue);
            TrainValueFunction = K.function(valueInputs, new List <Tensor> {
                outputValueLoss
            }, valueUpdates, "TrainValueFunction");

            // actor losses
            Tensor meanLoss, varLoss;
            using (K.name_scope("ActorLosses"))
            {
                Tensor posAdvantage;
                if (usePositiveAdvOnly)
                {
                    posAdvantage = K.identity(K.relu(K.mean(inputAdvantage)), "ClippedPositiveAdv");
                }
                else
                {
                    posAdvantage = K.identity(K.mean(inputAdvantage), "Adv");
                }
                var meanNoGrad   = K.stop_gradient(outputActionMeanFromNetwork, "MeanNoGrad");
                var varNoGrad    = K.stop_gradient(outputVariance, "VarNoGrad");
                var logVar       = outActionLogVarianceFromNetwork;
                var logVarNoGrad = K.stop_gradient(logVar, "LogVarNoGrad");
                using (K.name_scope("VarLoss"))
                {
                    var logpNoMeanGrad = -1.0f * K.sum(0.5f * K.square(inputOldAction - meanNoGrad) / outputVariance + 0.5f * logVar, 1);
                    varLoss = K.identity(-1.0f * K.mean(posAdvantage * logpNoMeanGrad), "VarLoss");
                }
                using (K.name_scope("MeanLoss"))
                {
                    var logpNoVarGrad = -1.0f * K.sum(0.5f * K.square(inputOldAction - outputActionMeanFromNetwork) / varNoGrad + 0.5f * logVarNoGrad, 1);
                    meanLoss = K.identity(-1.0f * K.mean(posAdvantage * logpNoVarGrad), "MeanLoss");
                }
            }

            //add inputs, outputs and parameters to the list
            List <Tensor> allInputs = new List <Tensor>();
            if (HasVectorObservation)
            {
                allInputs.Add(stateTensor);
            }
            if (HasVisualObservation)
            {
                allInputs.AddRange(inputVisualTensors);
            }
            allInputs.Add(inputOldAction);
            allInputs.Add(inputAdvantage);


            //create optimizer and create necessary functions
            var updatesMean = AddOptimizer(meanWeights, meanLoss, optimizerMean);
            var updatesVar  = AddOptimizer(varweights, varLoss, optimizerVariance);

            TrainMeanFunction = K.function(allInputs, new List <Tensor> {
                meanLoss
            }, updatesMean, "UpdateMeanFunction");
            TrainVarianceFunction = K.function(allInputs, new List <Tensor> {
                varLoss
            }, updatesVar, "UpdateVarianceFunction");

            //pretraining for output mean and var
            var inputInitialStd  = UnityTFUtils.Input(new int?[] { ActionSizes[0] }, name: "InputInitialStd")[0];
            var inputInitialMean = UnityTFUtils.Input(new int?[] { ActionSizes[0] }, name: "InputInitialMean")[0];
            var policyInitLoss   = K.mean(K.mean(K.square(inputInitialMean - outputActionMeanFromNetwork)));
            policyInitLoss += K.mean(K.mean(K.square(inputInitialStd - K.sqrt(outputVariance))));

            var updatesPretrain = AddOptimizer(network.GetActorWeights(), policyInitLoss, optimizerPretrain);
            var pretrainInputs  = new List <Tensor>();
            pretrainInputs.Add(stateTensor);
            pretrainInputs.Add(inputInitialMean);
            pretrainInputs.Add(inputInitialStd);
            PretrainFunction = K.function(pretrainInputs, new List <Tensor> {
                policyInitLoss
            }, updatesPretrain, "PretrainFunction");
        }
    }
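
Note: the SampleAction scope above draws actions by the reparameterization epsilon * sqrt(variance) + mean, with epsilon drawn from a standard normal. A plain-C# sketch of the same sampling for one action dimension, using a Box-Muller transform for the standard normal draw:

    using System;

    public static class ActionSamplingSketch
    {
        public static double SampleGaussianAction(Random rng, double mean, double variance)
        {
            //Box-Muller: turn two uniform draws into one standard normal draw
            double u1  = 1.0 - rng.NextDouble();    //avoid log(0)
            double u2  = rng.NextDouble();
            double eps = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
            return eps * Math.Sqrt(variance) + mean;
        }
    }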