Example #1
0
 /// <summary>
 /// Epoch-complete handler: scores the current recurrent network against the test data and, when the
 /// score is a new best, serializes the best network snapshot to <c>_dataFile</c>.
 /// </summary>
 /// <param name="context">Training context that raised the event.</param>
 /// <param name="recurrentContext">Recurrent training context for the epoch that just finished.</param>
 private void OnEpochComplete(ITrainingContext context, IRecurrentTrainingContext recurrentContext)
 {
     var isNewBest = _CalculateTestScore(context, _memory, _testData, _trainer, recurrentContext, ref _bestScore, ref _bestOutput);
     if (!isNewBest)
     {
         return;
     }

     // FileMode.Create replaces any snapshot written by an earlier epoch
     using (var stream = new FileStream(_dataFile, FileMode.Create, FileAccess.Write))
     {
         Serializer.Serialize(stream, _bestOutput);
     }
 }
        /// <summary>
        /// Scores a bidirectional network against <paramref name="data"/>, reports the score, and records a
        /// snapshot of the network when the score beats <paramref name="bestScore"/>.
        /// </summary>
        /// <param name="context">Receives the score through <c>WriteScore</c>.</param>
        /// <param name="forwardMemory">Current forward initial-memory vector.</param>
        /// <param name="backwardMemory">Current backward initial-memory vector.</param>
        /// <param name="data">Data to score against.</param>
        /// <param name="network">Trainer whose network is scored.</param>
        /// <param name="recurrentContext">Supplies the error metric used for scoring and comparison.</param>
        /// <param name="bestScore">Best score so far; replaced when improved.</param>
        /// <param name="output">Best network snapshot; replaced when the score improves.</param>
        /// <returns><c>true</c> if the score improved on <paramref name="bestScore"/>.</returns>
        protected bool _CalculateTestScore(ITrainingContext context, float[] forwardMemory, float[] backwardMemory, ISequentialTrainingDataProvider data, INeuralNetworkBidirectionalBatchTrainer network, IRecurrentTrainingContext recurrentContext, ref double bestScore, ref BidirectionalNetwork output)
        {
            var score       = _GetScore(data, network, forwardMemory, backwardMemory, recurrentContext);
            var errorMetric = recurrentContext.TrainingContext.ErrorMetric;
            var improved    = errorMetric.HigherIsBetter ? score > bestScore : score < bestScore;

            if (improved)
            {
                bestScore = score;
                output    = network.NetworkInfo;

                // Copy the memory vectors: the bidirectional Train method adds to these arrays in place, so
                // storing the live arrays would let the "best" snapshot drift away from the values that
                // actually produced this score.
                output.ForwardMemory = new FloatArray {
                    Data = forwardMemory == null ? null : (float[])forwardMemory.Clone()
                };
                output.BackwardMemory = new FloatArray {
                    Data = backwardMemory == null ? null : (float[])backwardMemory.Clone()
                };
            }
            context.WriteScore(score, errorMetric.DisplayAsPercentage, improved);
            return improved;
        }
        /// <summary>
        /// Scores a recurrent network against <paramref name="data"/>, reports the score, records a snapshot of
        /// the network when the score beats <paramref name="bestScore"/>, and optionally reduces the training
        /// rate after a run of evaluations without improvement.
        /// </summary>
        /// <param name="context">Receives the score through <c>WriteScore</c>; its training rate may be reduced.</param>
        /// <param name="memory">Current initial-memory vector.</param>
        /// <param name="data">Data to score against.</param>
        /// <param name="network">Trainer whose network is scored.</param>
        /// <param name="recurrentContext">Supplies the error metric used for scoring and comparison.</param>
        /// <param name="bestScore">Best score so far; replaced when improved.</param>
        /// <param name="output">Best network snapshot; replaced when the score improves.</param>
        /// <returns><c>true</c> if the score improved on <paramref name="bestScore"/>.</returns>
        protected bool _CalculateTestScore(ITrainingContext context, float[] memory, ISequentialTrainingDataProvider data, INeuralNetworkRecurrentBatchTrainer network, IRecurrentTrainingContext recurrentContext, ref double bestScore, ref RecurrentNetwork output)
        {
            var score       = _GetScore(data, network, memory, recurrentContext);
            var errorMetric = recurrentContext.TrainingContext.ErrorMetric;
            var improved    = errorMetric.HigherIsBetter ? score > bestScore : score < bestScore;

            if (improved)
            {
                bestScore = score;
                output    = network.NetworkInfo;

                // Copy the memory: TrainOnMiniBatch adds to this array in place, so storing the live array would
                // let the "best" snapshot drift away from the memory that actually produced this score.
                output.Memory = new FloatArray {
                    Data = memory == null ? null : (float[])memory.Clone()
                };
            }
            context.WriteScore(score, errorMetric.DisplayAsPercentage, improved);

            // count consecutive evaluations without improvement
            _noChange = improved ? 0 : _noChange + 1;

            // after the configured number of non-improving evaluations, lower the rate and restore the best params
            if (_autoAdjustOnNoChangeCount.HasValue && _noChange >= _autoAdjustOnNoChangeCount.Value)
            {
                context.ReduceTrainingRate();
                Console.WriteLine("Reducing training rate to " + context.TrainingRate);
                ApplyBestParams();
                _noChange = 0;
            }
            return improved;
        }
 /// <summary>
 /// Runs the bidirectional network over <paramref name="data"/> and returns the absolute value of the
 /// mean error-metric value computed over every result of every sequence position.
 /// </summary>
 /// <param name="data">Data to execute.</param>
 /// <param name="network">Bidirectional trainer to execute.</param>
 /// <param name="forwardMemory">Forward initial-memory vector.</param>
 /// <param name="backwardMemory">Backward initial-memory vector.</param>
 /// <param name="context">Recurrent context; its training context supplies the error metric.</param>
 protected double _GetScore(ISequentialTrainingDataProvider data, INeuralNetworkBidirectionalBatchTrainer network, float[] forwardMemory, float[] backwardMemory, IRecurrentTrainingContext context)
 {
     var results = network.Execute(data, forwardMemory, backwardMemory, context);
     var errors  = from sequence in results
                   from result in sequence
                   select context.TrainingContext.ErrorMetric.Compute(result.Output, result.ExpectedOutput);
     return Math.Abs(errors.Average());
 }
Example #5
0
        /// <summary>
        /// Trains the recurrent network, scoring it against the test data after every epoch, and finally
        /// applies the best parameters seen.
        /// </summary>
        /// <param name="trainingData">Sequences to train on.</param>
        /// <param name="numEpochs">Maximum number of epochs.</param>
        /// <param name="context">Training context; used to build a recurrent context when none is supplied.</param>
        /// <param name="recurrentContext">Optional recurrent context; created from <paramref name="context"/> when null.</param>
        public void Train(ISequentialTrainingDataProvider trainingData, int numEpochs, ITrainingContext context, IRecurrentTrainingContext recurrentContext = null)
        {
            if (recurrentContext == null)
            {
                recurrentContext = new RecurrentContext(_trainer.LinearAlgebraProvider, context);
            }

            _bestScore = _GetScore(_testData, _trainer, _memory, recurrentContext);
            Console.WriteLine(context.ErrorMetric.DisplayAsPercentage ? "Initial score: {0:P}" : "Initial score: {0}", _bestScore);

            _bestOutput = null;
            recurrentContext.TrainingContext.RecurrentEpochComplete += OnEpochComplete;
            try
            {
                _memory = _trainer.Train(trainingData, _memory, numEpochs, recurrentContext);
            }
            finally
            {
                // always detach: a failed run must not leave this object subscribed to a context the caller may reuse
                recurrentContext.TrainingContext.RecurrentEpochComplete -= OnEpochComplete;
            }

            // ensure best values are current
            ApplyBestParams();
        }
 /// <summary>
 /// Executes the network over <paramref name="data"/> and returns the mean error-metric value across
 /// every result of every sequence position.
 /// </summary>
 /// <param name="data">Sequences to execute.</param>
 /// <param name="memory">Initial-memory vector.</param>
 /// <param name="context">Recurrent context; its training context supplies the error metric.</param>
 public float CalculateCost(ISequentialTrainingDataProvider data, float[] memory, IRecurrentTrainingContext context)
 {
     var costs = Execute(data, memory, context)
         .SelectMany(sequence => sequence)
         .Select(result => context.TrainingContext.ErrorMetric.Compute(result.Output, result.ExpectedOutput));
     return costs.Average();
 }
        /// <summary>
        /// Runs the recurrent network forward (no training) over every mini-batch of <paramref name="trainingData"/>
        /// and collects, for each sequence position, the network output, expected output and memory output per row.
        /// </summary>
        /// <param name="trainingData">Sequences to execute (mini-batches are taken in order, not shuffled).</param>
        /// <param name="memory">Initial-memory vector.</param>
        /// <param name="context">Drives the forward pass; its training context supplies the mini-batch size.</param>
        /// <returns>One array of results per sequence position, ordered by position.</returns>
        public IReadOnlyList <IRecurrentExecutionResults[]> Execute(ISequentialTrainingDataProvider trainingData, float[] memory, IRecurrentTrainingContext context)
        {
            var sequenceOutput = new Dictionary <int, List <IRecurrentExecutionResults> >();
            var batchSize      = context.TrainingContext.MiniBatchSize;

            foreach (var miniBatch in _GetMiniBatches(trainingData, false, batchSize))
            {
                _lap.PushLayer();
                try
                {
                    context.ExecuteForward(miniBatch, memory, (k, fc) => {
                        foreach (var action in _layer)
                        {
                            action.Execute(fc, false);
                        }

                        // fc[0] is read as the network output and fc[1] as the memory output
                        var memoryOutput = fc[1].AsIndexable().Rows.ToList();

                        // store the output under its sequence position
                        List <IRecurrentExecutionResults> temp;
                        if (!sequenceOutput.TryGetValue(k, out temp))
                        {
                            sequenceOutput.Add(k, temp = new List <IRecurrentExecutionResults>());
                        }
                        var ret = fc[0].AsIndexable().Rows.Zip(miniBatch.GetExpectedOutput(fc, k).AsIndexable().Rows, (a, e) => Tuple.Create(a, e));
                        temp.AddRange(ret.Zip(memoryOutput, (t, d) => new RecurrentExecutionResults(t.Item1, t.Item2, d)));
                    });

                    context.TrainingContext.EndBatch();
                }
                finally
                {
                    // keep the layer stack balanced and release the batch even if execution throws
                    _lap.PopLayer();
                    miniBatch.Dispose();
                }
            }
            return(sequenceOutput.OrderBy(kv => kv.Key).Select(kv => kv.Value.ToArray()).ToList());
        }
        /// <summary>
        /// Trains the recurrent network for up to <paramref name="numEpochs"/> epochs (stopping early when the
        /// training context says not to continue), updating <paramref name="memory"/> in place.
        /// </summary>
        /// <param name="trainingData">Sequences to train on.</param>
        /// <param name="memory">Initial-memory vector; modified in place and also returned.</param>
        /// <param name="numEpochs">Maximum number of epochs.</param>
        /// <param name="context">Recurrent context; its training context controls epochs, batches and rate.</param>
        /// <returns>The same <paramref name="memory"/> array.</returns>
        public float[] Train(ISequentialTrainingDataProvider trainingData, float[] memory, int numEpochs, IRecurrentTrainingContext context)
        {
            var trainingContext = context.TrainingContext;

            for (int i = 0; i < numEpochs && context.TrainingContext.ShouldContinue; i++)
            {
                trainingContext.StartEpoch(trainingData.Count);
                var batchErrorList = new List <double>();

                foreach (var miniBatch in _GetMiniBatches(trainingData, _stochastic, trainingContext.MiniBatchSize))
                {
                    try
                    {
                        TrainOnMiniBatch(miniBatch, memory, context, curr => {
                            // get a measure of the training error; curr can be null before any expected output is seen
                            if (_collectTrainingError && curr != null)
                            {
                                batchErrorList.Add(curr.AsIndexable().Values.Select(v => Math.Pow(v, 2)).Average() / 2);
                            }
                        }, null);
                    }
                    finally
                    {
                        miniBatch.Dispose();
                    }
                }

                // Average() throws on an empty list (e.g. no mini-batches this epoch), so report 0 instead
                var trainingError = _collectTrainingError && batchErrorList.Count > 0 ? batchErrorList.Average() : 0;
                trainingContext.EndRecurrentEpoch(trainingError, context);
            }
            return(memory);
        }
        /// <summary>
        /// Trains the recurrent network on one mini-batch: runs the sequence forward, backpropagates through the
        /// recorded steps from last to first, then adjusts the initial memory with the remaining error signal.
        /// </summary>
        /// <param name="miniBatch">Mini-batch to train on; not disposed here.</param>
        /// <param name="memory">Initial-memory vector; updated in place.</param>
        /// <param name="context">Drives the forward pass and supplies the training context.</param>
        /// <param name="beforeBackProp">Optional callback given the error signal before each step's backpropagation
        /// (the signal is null if no step with an expected output has been processed yet).</param>
        /// <param name="afterBackProp">Optional callback given the error signal after each step's backpropagation.</param>
        public void TrainOnMiniBatch(ISequentialMiniBatch miniBatch, float[] memory, IRecurrentTrainingContext context, Action <IMatrix> beforeBackProp, Action <IMatrix> afterBackProp)
        {
            var trainingContext = context.TrainingContext;

            _lap.PushLayer();
            try
            {
                var updateStack = new Stack <Tuple <Stack <INeuralNetworkRecurrentBackpropagation>, IMatrix, IMatrix, ISequentialMiniBatch, int> >();

                // forward pass: record each step's backpropagation actions, expected output and network output
                context.ExecuteForward(miniBatch, memory, (k, fc) => {
                    var layerStack = new Stack <INeuralNetworkRecurrentBackpropagation>();
                    foreach (var action in _layer)
                    {
                        layerStack.Push(action.Execute(fc, true));
                    }
                    updateStack.Push(Tuple.Create(layerStack, miniBatch.GetExpectedOutput(fc, k), fc[0], miniBatch, k));
                });

                // backpropagate (most recent step first), accumulating errors across the sequence
                using (var updateAccumulator = new UpdateAccumulator(trainingContext)) {
                    IMatrix curr = null;
                    while (updateStack.Any())
                    {
                        var update      = updateStack.Pop();
                        var actionStack = update.Item1;

                        // a step without an expected output keeps propagating the signal carried from the later step
                        var expectedOutput = update.Item2;
                        if (expectedOutput != null)
                        {
                            curr = trainingContext.ErrorMetric.CalculateDelta(update.Item3, expectedOutput);
                        }

                        beforeBackProp?.Invoke(curr);
                        while (actionStack.Any())
                        {
                            // NOTE(review): the bidirectional trainer passes `actionStack.Any() || isT0` here; this
                            // path always requests the output (true) - confirm that is intended
                            curr = actionStack.Pop().Execute(curr, trainingContext, true, updateAccumulator);
                        }
                        afterBackProp?.Invoke(curr);

                        // apply any filters
                        foreach (var filter in _filter)
                        {
                            filter.AfterBackPropagation(update.Item4, update.Item5, curr);
                        }
                    }

                    // adjust the initial memory against the error signal left after the first step
                    if (curr != null)
                    {
                        using (var columnSums = curr.ColumnSums()) {
                            var initialDelta = columnSums.AsIndexable();
                            for (var j = 0; j < memory.Length; j++)
                            {
                                memory[j] += initialDelta[j] * trainingContext.TrainingRate;
                            }
                        }
                    }
                }

                trainingContext.EndBatch();
            }
            finally
            {
                // keep the layer stack balanced even if training throws
                _lap.PopLayer();
            }
        }
        /// <summary>
        /// Runs the bidirectional network (no training) over every mini-batch of <paramref name="trainingData"/>
        /// and collects, for each sequence position, the network output, expected output and memory output per row.
        /// </summary>
        /// <param name="trainingData">Sequences to execute (mini-batches are taken in order, not shuffled).</param>
        /// <param name="forwardMemory">Forward initial-memory vector.</param>
        /// <param name="backwardMemory">Backward initial-memory vector.</param>
        /// <param name="context">Drives the bidirectional pass; its training context supplies the mini-batch size.</param>
        /// <returns>One array of results per sequence position, ordered by position.</returns>
        public IReadOnlyList <IRecurrentExecutionResults[]> Execute(ISequentialTrainingDataProvider trainingData, float[] forwardMemory, float[] backwardMemory, IRecurrentTrainingContext context)
        {
            var sequenceOutput = new Dictionary <int, List <IRecurrentExecutionResults> >();
            var batchSize      = context.TrainingContext.MiniBatchSize;

            foreach (var miniBatch in _GetMiniBatches(trainingData, false, batchSize))
            {
                _lap.PushLayer();
                try
                {
                    var sequenceLength = miniBatch.SequenceLength;
                    context.ExecuteBidirectional(miniBatch, _layer, forwardMemory, backwardMemory, _padding, null, (memoryOutput, output) => {
                        // store the output under each sequence position
                        for (var k = 0; k < sequenceLength; k++)
                        {
                            List <IRecurrentExecutionResults> temp;
                            if (!sequenceOutput.TryGetValue(k, out temp))
                            {
                                sequenceOutput.Add(k, temp = new List <IRecurrentExecutionResults>());
                            }
                            var ret = output[k].AsIndexable().Rows.Zip(miniBatch.GetExpectedOutput(output, k).AsIndexable().Rows, (a, e) => Tuple.Create(a, e));
                            temp.AddRange(ret.Zip(memoryOutput[k], (t, d) => new RecurrentExecutionResults(t.Item1, t.Item2, d)));
                        }
                    });

                    context.TrainingContext.EndBatch();
                }
                finally
                {
                    // keep the layer stack balanced and release the batch even if execution throws
                    _lap.PopLayer();
                    miniBatch.Dispose();
                }
            }
            return(sequenceOutput.OrderBy(kv => kv.Key).Select(kv => kv.Value.ToArray()).ToList());
        }
        /// <summary>
        /// Trains the bidirectional network for up to <paramref name="numEpochs"/> epochs (stopping early when the
        /// training context says not to continue), updating both memory vectors in place.
        /// </summary>
        /// <param name="trainingData">Sequences to train on.</param>
        /// <param name="forwardMemory">Forward initial-memory vector; modified in place.</param>
        /// <param name="backwardMemory">Backward initial-memory vector; modified in place.</param>
        /// <param name="numEpochs">Maximum number of epochs.</param>
        /// <param name="context">Recurrent context; its training context controls epochs, batches, rate and logging.</param>
        /// <returns>A <c>BidirectionalMemory</c> wrapping the (updated) memory arrays.</returns>
        public BidirectionalMemory Train(ISequentialTrainingDataProvider trainingData, float[] forwardMemory, float[] backwardMemory, int numEpochs, IRecurrentTrainingContext context)
        {
            var trainingContext = context.TrainingContext;

            for (int i = 0; i < numEpochs && context.TrainingContext.ShouldContinue; i++)
            {
                trainingContext.StartEpoch(trainingData.Count);
                var batchErrorList = new List <double>();

                foreach (var miniBatch in _GetMiniBatches(trainingData, _stochastic, trainingContext.MiniBatchSize))
                {
                    _lap.PushLayer();
                    try
                    {
                        _TrainOnBidirectionalMiniBatch(miniBatch, forwardMemory, backwardMemory, context, batchErrorList);
                        trainingContext.EndBatch();
                    }
                    finally
                    {
                        // keep the layer stack balanced and release the batch even if training throws
                        _lap.PopLayer();
                        miniBatch.Dispose();
                    }
                }

                // Average() throws on an empty list (e.g. no mini-batches this epoch), so report 0 instead
                var trainingError = _collectTrainingError && batchErrorList.Count > 0 ? batchErrorList.Average() : 0;
                trainingContext.EndRecurrentEpoch(trainingError, context);
            }
            return(new BidirectionalMemory(forwardMemory, backwardMemory));
        }

        // Runs one mini-batch through the bidirectional network, backpropagates the recorded steps and adjusts both
        // initial-memory vectors; appends per-step training errors to batchErrorList when error collection is on.
        private void _TrainOnBidirectionalMiniBatch(ISequentialMiniBatch miniBatch, float[] forwardMemory, float[] backwardMemory, IRecurrentTrainingContext context, List <double> batchErrorList)
        {
            var trainingContext = context.TrainingContext;
            var logger          = trainingContext.Logger;
            var updateStack     = new Stack <Tuple <Stack <Tuple <INeuralNetworkRecurrentBackpropagation, INeuralNetworkRecurrentBackpropagation> >, IMatrix, IMatrix, ISequentialMiniBatch, int> >();
            context.ExecuteBidirectional(miniBatch, _layer, forwardMemory, backwardMemory, _padding, updateStack, null);

            // backpropagate, accumulating errors across the sequence
            using (var updateAccumulator = new UpdateAccumulator(trainingContext)) {
                while (updateStack.Any())
                {
                    var update      = updateStack.Pop();
                    var isT0        = !updateStack.Any();
                    var actionStack = update.Item1;

                    // calculate error
                    var expectedOutput = update.Item2;
                    var curr           = new List <IMatrix>();
                    curr.Add(trainingContext.ErrorMetric.CalculateDelta(update.Item3, expectedOutput));

                    // get a measure of the training error
                    if (_collectTrainingError)
                    {
                        foreach (var item in curr)
                        {
                            batchErrorList.Add(item.AsIndexable().Values.Select(v => Math.Pow(v, 2)).Average() / 2);
                        }
                    }

                    #region logging
                    if (logger != null)
                    {
                        logger.WriteStartElement("initial-error");
                        foreach (var item in curr)
                        {
                            item.WriteTo(logger);
                        }
                        logger.WriteEndElement();
                    }
                    #endregion

                    // backpropagate; Item1's signal is kept in curr[0] and Item2's in curr[1]
                    while (actionStack.Any())
                    {
                        var backpropagationAction = actionStack.Pop();

                        // when both directions are present and the signal is still combined, split it in two
                        if (backpropagationAction.Item1 != null && backpropagationAction.Item2 != null && curr.Count == 1)
                        {
                            using (var m = curr[0]) {
                                var split = m.SplitRows(forwardMemory.Length);
                                curr[0] = split.Left;
                                curr.Add(split.Right);
                            }
                            #region logging
                            if (logger != null)
                            {
                                logger.WriteStartElement("post-split");
                                foreach (var item in curr)
                                {
                                    item.WriteTo(logger);
                                }
                                logger.WriteEndElement();
                            }
                            #endregion
                        }
                        if (backpropagationAction.Item1 != null)
                        {
                            using (var m = curr[0])
                                curr[0] = backpropagationAction.Item1.Execute(m, trainingContext, actionStack.Any() || isT0, updateAccumulator);
                        }
                        if (backpropagationAction.Item2 != null)
                        {
                            using (var m = curr[1])
                                curr[1] = backpropagationAction.Item2.Execute(m, trainingContext, actionStack.Any() || isT0, updateAccumulator);
                        }
                        #region logging
                        if (logger != null)
                        {
                            logger.WriteStartElement("error");
                            foreach (var item in curr)
                            {
                                item.WriteTo(logger);
                            }
                            logger.WriteEndElement();
                        }
                        #endregion
                    }

                    // apply any filters
                    foreach (var filter in _filter)
                    {
                        foreach (var item in curr)
                        {
                            filter.AfterBackPropagation(update.Item4, update.Item5, item);
                        }
                    }

                    // adjust the initial memory against the error signal
                    // NOTE(review): curr[1] only exists if a split happened above - confirm every network reaches it
                    if (isT0)
                    {
                        using (var columnSums0 = curr[0].ColumnSums())
                            using (var columnSums1 = curr[1].ColumnSums()) {
                                var initialDelta = columnSums0.AsIndexable();
                                for (var j = 0; j < forwardMemory.Length; j++)
                                {
                                    forwardMemory[j] += initialDelta[j] * trainingContext.TrainingRate;
                                }

                                initialDelta = columnSums1.AsIndexable();
                                for (var j = 0; j < backwardMemory.Length; j++)
                                {
                                    backwardMemory[j] += initialDelta[j] * trainingContext.TrainingRate;
                                }
                            }
                    }
                }
            }
        }
Example #12
0
 /// <summary>
 /// Ends the current epoch with the given training error, then raises <c>RecurrentEpochComplete</c>
 /// (when it has subscribers) with this context and <paramref name="context"/>.
 /// </summary>
 /// <param name="trainingError">Training error reported for the epoch.</param>
 /// <param name="context">Recurrent training context passed to subscribers.</param>
 public void EndRecurrentEpoch(double trainingError, IRecurrentTrainingContext context)
 {
     EndEpoch(trainingError);

     // read the delegate once so a concurrent unsubscribe cannot null it between check and call
     var handler = RecurrentEpochComplete;
     if (handler != null)
     {
         handler(this, context);
     }
 }