/// <summary>
    /// Runs one supervised-learning update step on a batch and returns the value produced by
    /// <c>UpdateSLFunction</c> (the loss). Intended for <c>Mode.SupervisedLearning</c> on a model
    /// initialized with training enabled; both preconditions are checked only via <c>Debug.Assert</c>.
    /// </summary>
    /// <param name="vectorObservations">Batch of vector observations; not fed when null.</param>
    /// <param name="visualObservations">One array per visual observation input, fed in list order; not fed when null.</param>
    /// <param name="actions">
    /// Target actions for the batch. For a continuous action space they are fed as-is; for a discrete
    /// action space each value is rounded to the nearest int before being fed.
    /// </param>
    /// <returns>The first output of <c>UpdateSLFunction</c>, evaluated and cast to float.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="actions"/> is null.</exception>
    public float TrainBatch(float[,] vectorObservations, List <float[, , , ]> visualObservations, float[,] actions)
    {
        Debug.Assert(mode == Mode.SupervisedLearning, "This method is for SupervisedLearning mode only. Please set the mode of RLModePPO to SupervisedLearning in the editor.");
        Debug.Assert(TrainingEnabled == true, "The model needs to be initialized with Training enabled to use TrainBatch()");
        // Fail fast: a null target would otherwise surface as an NRE inside Convert (discrete)
        // or as a null entry handed to the backend (continuous).
        if (actions == null)
        {
            throw new ArgumentNullException("actions", "Supervised training requires target actions.");
        }

        // Inputs are positional: vector obs, then visual obs, then the action targets.
        var inputs = new List <Array>();

        if (vectorObservations != null)
        {
            inputs.Add(vectorObservations);
        }
        if (visualObservations != null)
        {
            inputs.AddRange(visualObservations);
        }

        if (ActionSpace == SpaceType.continuous)
        {
            inputs.Add(actions);
        }
        else if (ActionSpace == SpaceType.discrete)
        {
            // Discrete action indices arrive in the float[,] buffer; round them back to ints.
            // NOTE(review): unlike the overload taking actionsMask, no action masks are fed here —
            // confirm UpdateSLFunction's expected input list matches this layout.
            int[,] actionsInt = actions.Convert(t => Mathf.RoundToInt(t));
            inputs.Add(actionsInt);
        }

        var loss = UpdateSLFunction.Call(inputs);
        return (float)loss[0].eval();
    }
    /// <summary>
    /// Runs one supervised-learning update step on a batch, optionally with action masks, and returns
    /// the value produced by <c>UpdateSLFunction</c> (the loss). Requires a model initialized with
    /// training enabled (checked only via <c>Debug.Assert</c>).
    /// </summary>
    /// <param name="vectorObservations">Batch of vector observations; not fed when null.</param>
    /// <param name="visualObservations">One array per visual observation input, fed in list order; not fed when null.</param>
    /// <param name="actions">
    /// Target actions for the batch; its first dimension is used as the batch size. For a continuous
    /// action space they are fed as-is; for a discrete action space each value is rounded to the nearest int.
    /// </param>
    /// <param name="actionsMask">
    /// Action masks fed just before the actions when the action space is discrete. When null, all-ones
    /// masks are built with <c>CreateDummyMasks(ActionSizes, batchSize)</c>. Ignored for a continuous action space.
    /// </param>
    /// <returns>The first output of <c>UpdateSLFunction</c>, evaluated and cast to float.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="actions"/> is null.</exception>
    public float TrainBatch(float[,] vectorObservations, List <float[, , , ]> visualObservations, float[,] actions, List <float[, ]> actionsMask = null)
    {
        // NOTE(review): unlike the three-argument overload, this one does not assert
        // mode == Mode.SupervisedLearning; confirm that is intentional.
        Debug.Assert(TrainingEnabled == true, "The model needs to be initialized with Training enabled to use TrainBatch()");
        // Fail fast: a null target would otherwise surface as an NRE on GetLength/Convert (discrete)
        // or as a null entry handed to the backend (continuous).
        if (actions == null)
        {
            throw new ArgumentNullException("actions", "Supervised training requires target actions.");
        }

        // Inputs are positional: vector obs, visual obs, [masks (discrete only)], action targets.
        var inputs = new List <Array>();

        if (vectorObservations != null)
        {
            inputs.Add(vectorObservations);
        }
        if (visualObservations != null)
        {
            inputs.AddRange(visualObservations);
        }

        if (ActionSpace == SpaceType.continuous)
        {
            inputs.Add(actions);
        }
        else if (ActionSpace == SpaceType.discrete)
        {
            // Fall back to all-ones masks sized to this batch when the caller supplies none.
            List <float[, ]> masks = actionsMask ?? CreateDummyMasks(ActionSizes, actions.GetLength(0));
            inputs.AddRange(masks);

            // Discrete action indices arrive in the float[,] buffer; round them back to ints.
            int[,] actionsInt = actions.Convert(t => Mathf.RoundToInt(t));
            inputs.Add(actionsInt);
        }

        var loss = UpdateSLFunction.Call(inputs);
        return (float)loss[0].eval();
    }