Code example #1
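The experience-replay training step of a Double DQN agent: the online model selects the best action for each next state, the target model evaluates it, and the online model is then fitted on the corrected Q-value targets.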
        protected override void AgentReplay(
            int batchSize,
            OptimizerBase optimizer,
            MetricFunction lossMetric,
            bool shuffle)
        {
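            // Draw a batch of transitions (state, action, reward, next state) from the replay memory.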
            var batch = shuffle ? ReplayMemory.ToShuffledBatch(batchSize) : ReplayMemory.ToRandomBatch(batchSize);

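            // Collect the current states and the next states of every sample; terminal
            // transitions (null next state) are padded with zeros.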
            var states       = new DataFrame<float>(StateShape);
            var statesTarget = new DataFrame<float>(StateShape);

            foreach (var sample in batch)
            {
                states.Add(sample.Item1);
                statesTarget.Add(sample.Item4 ?? new float[StateShape.TotalSize]);
            }

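            // Q-values of the current states (online model), of the next states under the online
            // model (used to select the best next action) and under the target model (used to
            // evaluate that action): the Double DQN decomposition.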
            var prediction = Model.Predict(states);
            var predictionOfTargetStates = Model.Predict(statesTarget);
            var predictionTarget         = TargetModel.Predict(statesTarget);

            var data = new DataFrameList<float>(StateShape, ActionShape);

            for (var i = 0; i < batch.Length; i++)
            {
                var sample = batch[i];

                var t = prediction[i];

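                // Build the TD target for the action that was actually taken (Item2).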
                if (sample.Item4 == null)
                {
                    t[sample.Item2] = sample.Item3;
                }
                else
                {
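                    // Find the action with the highest Q-value in the online model's prediction
                    // for the next state, then evaluate it with the target model (Double DQN).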
                    var lastValue  = float.MinValue;
                    var valueIndex = 0;

                    for (var j = 0; j < predictionOfTargetStates[i].Length; j++)
                    {
                        if (predictionOfTargetStates[i][j] > lastValue)
                        {
                            lastValue  = predictionOfTargetStates[i][j];
                            valueIndex = j;
                        }
                    }

                    t[sample.Item2] = (float)(sample.Item3 + DiscountFactor * predictionTarget[i][valueIndex]);
                }

                data.AddFrame(sample.Item1, t);
            }

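            // Fit the online model for one epoch on the updated Q-value targets.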
            Model.Fit(data, 1, batch.Length, optimizer, lossMetric);
        }
Code example #2
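A batched training step over a list of experiences. Double DQN can be toggled via EnableDoubleDQN, and the absolute TD errors are fed back to the replay memory together with per-sample importance-sampling weights, as used by prioritized experience replay.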
        protected void Train(List<Experience> experiences)
        {
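            // Use the first experience to determine the state shape and size the batched input tensors.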
            var e0          = experiences[0];
            var stateShape  = e0.State.Shape;
            var actionShape = e0.Action.Shape;

            Tensor statesBatch     = new Tensor(new Shape(stateShape.Width, stateShape.Height, stateShape.Depth, experiences.Count));
            Tensor nextStatesBatch = new Tensor(statesBatch.Shape);

            for (int i = 0; i < experiences.Count; ++i)
            {
                var e = experiences[i];
                e.State.CopyBatchTo(0, i, statesBatch);
                e.NextState.CopyBatchTo(0, i, nextStatesBatch);
            }

            Tensor rewardsBatch             = Net.Predict(statesBatch)[0]; // this is our original prediction
            Tensor futureRewardsBatch       = EnableDoubleDQN ? Net.Predict(nextStatesBatch)[0] : null; // online net on next states (action selection, Double DQN only)
            Tensor futureTargetRewardsBatch = TargetModel.Predict(nextStatesBatch)[0];                  // target net on next states (action evaluation)

            List<float> absErrors = new List<float>();

            ImportanceSamplingWeights.Zero();

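            // Compute the TD target and the absolute TD error for every experience in the batch.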
            for (int i = 0; i < experiences.Count; ++i)
            {
                var e = experiences[i];

                float futureReward = 0;

                if (EnableDoubleDQN)
                {
                    var nextBestAction = futureRewardsBatch.ArgMax(i);
                    futureReward = futureTargetRewardsBatch[0, nextBestAction, 0, i];
                }
                else
                {
                    futureReward = futureTargetRewardsBatch.Max(i);
                }

                var estimatedReward = e.Reward;
                if (!e.Done)
                {
                    estimatedReward += DiscountFactor * futureReward;
                }

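                // TD error between the new target and the network's current estimate for the taken action.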
                float error = estimatedReward - rewardsBatch[0, (int)e.Action[0], 0, i];
                absErrors.Add(Math.Abs(error));

                rewardsBatch[0, (int)e.Action[0], 0, i] = estimatedReward;
                ImportanceSamplingWeights[0, (int)e.Action[0], 0, i] = e.ImportanceSamplingWeight;
            }

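            // Report the absolute TD errors back so the replay memory can update its sample priorities.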
            Memory.Update(experiences, absErrors);

            var avgError = absErrors.Sum() / experiences.Count;

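            // Incrementally update the running average of the absolute TD error.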
            ++TrainingsDone;
            PerEpisodeErrorAvg += (avgError - PerEpisodeErrorAvg) / TrainingsDone;

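            // Train the network on the corrected Q-value targets.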
            Net.Fit(new List<Data> {
                new Data(statesBatch, rewardsBatch)
            }, -1, TrainingEpochs, null, 0, Track.Nothing);
        }