Example #1
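UpdateTargetNet rebuilds targetNet as a deep copy of policyNet's layers, so the target network keeps its own weights between syncs, the usual target-network pattern in DQN-style training.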
 private void UpdateTargetNet()
 {
     targetNet = new SequentialNet(policyNet.Layers);

     // Deep-copy every layer so the target net does not share weights
     // with the policy net, then re-link each copy to the previous one.
     Layer[] layers = policyNet.Layers.ToArray();
     for (int i = 0; i < layers.Length; i++)
     {
         layers[i] = layers[i].Copy();
         layers[i].IsInputLayer = i == 0;                       // only the first layer is the input layer
         layers[i].LastLayer    = i > 0 ? layers[i - 1] : null; // point each copy at the previous copy
     }

     targetNet.Layers = layers;
 }
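A minimal sketch of how such a method is typically called: re-sync the target net every fixed number of training steps. The trainStep counter, TargetSyncInterval constant, and OnTrainStep hook below are assumptions for illustration, not part of the original code.

 private int trainStep;
 private const int TargetSyncInterval = 100; // hypothetical sync period (assumption)

 private void OnTrainStep()
 {
     trainStep++;

     // Re-copy the policy weights into the frozen target net every
     // TargetSyncInterval steps, as is typical for DQN-style agents.
     if (trainStep % TargetSyncInterval == 0)
     {
         UpdateTargetNet();
     }
 }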
Example #2
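Start wires a SequentialNet to a backpropagation trainer, unpacks the sample data into parallel input/output arrays, and runs training.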
    /// <summary>
    /// Unity Start: initializes the network and trains it on the sample data.
    /// </summary>
    private void Start()
    {
        // Fetch the network configured on the SequentialNetBehaviour component.
        nn = nnBehaviour.GetSequentialNet();

        nn.Init(true, 1f);

        nn.PrintModelInfo(true);

        // Plain backpropagation with mean squared error and a 0.05 learning rate.
        Trainer trainer = new Trainers.BackPropagation(nn, Errors.MeanSquaredError, 0.05f);

        // Split each sample into parallel arrays:
        // sample[0] holds the inputs, sample[1] the expected outputs.
        float[][] inputs          = new float[dataTemplate.Length][];
        float[][] expectedOutputs = new float[dataTemplate.Length][];

        for (int i = 0; i < inputs.Length; i++)
        {
            float[][] sample = dataTemplate[i];
            inputs[i]          = sample[0];
            expectedOutputs[i] = sample[1];
        }

        // Run 400 training iterations over the data set.
        trainer.Train(inputs, expectedOutputs, 400);
    }
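The loop above expects each dataTemplate entry to hold the inputs at index 0 and the expected outputs at index 1. A hypothetical dataTemplate for XOR in that layout (the data itself is an assumption; the field shape matches how the loop indexes it):

    // Hypothetical XOR data set: each entry is { inputs, expected outputs }.
    private float[][][] dataTemplate =
    {
        new[] { new[] { 0f, 0f }, new[] { 0f } },
        new[] { new[] { 0f, 1f }, new[] { 1f } },
        new[] { new[] { 1f, 0f }, new[] { 1f } },
        new[] { new[] { 1f, 1f }, new[] { 0f } },
    };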
Example #3
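Awake validates the agent's configuration, builds the policy and target networks, and sets up the trainer, replay memory, and initial exploration rate for a DQN-style agent.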
        private void Awake()
        {
            if (env == null)
            {
                throw new Exception("Environment is null!");
            }

            if (ActionsSize < 2)
            {
                throw new Exception("Agent must have 2 or more actions");
            }

            // Build the policy network and mirror it into the target network.
            policyNet = GetComponent<SequentialNetBehaviour>().GetSequentialNet();
            policyNet.Init();
            UpdateTargetNet();

            // Backpropagation trainer with mean squared error for the policy net.
            policyNetTrainer = new Trainers.BackPropagation(policyNet, Errors.MeanSquaredError, LearningRate, false);

            // Experience replay buffer for sampled mini-batch training.
            replayMemory = new ReplayMemory(replayMemoryCapacity, memorySampleSize);

            // Start fully exploratory; the rate is decayed during training.
            explorationRate = MaxExplorationRate;
        }
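explorationRate starts at MaxExplorationRate and is typically annealed toward a minimum as training progresses. A minimal sketch of exponential epsilon decay, assuming MinExplorationRate and ExplorationDecayRate fields (both names are assumptions, not from the original code):

        // Sketch: exponentially anneal epsilon from MaxExplorationRate
        // toward the assumed MinExplorationRate as episodes accumulate.
        private void DecayExplorationRate(int episode)
        {
            explorationRate = MinExplorationRate
                + (MaxExplorationRate - MinExplorationRate)
                * Mathf.Exp(-ExplorationDecayRate * episode);
        }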