Exemplo n.º 1
0
        /// <summary>
        /// Trains on a single sequential mini batch: executes the forward pass for every sequence step,
        /// then backpropagates the error through the steps in reverse order and finally adjusts the
        /// initial memory from the remaining error signal.
        /// </summary>
        /// <param name="miniBatch">The sequential mini batch to train on.</param>
        /// <param name="memory">The initial memory. It is updated in place from the column sums of the final error signal, scaled by the training rate.</param>
        /// <param name="context">The recurrent training context that drives the forward pass and supplies the training context.</param>
        /// <param name="beforeBackProp">Optional callback invoked with the current error signal before each sequence step is backpropagated. The signal is null until a step with an expected output has been seen.</param>
        /// <param name="afterBackProp">Optional callback invoked with the error signal after each sequence step has been backpropagated.</param>
        public void TrainOnMiniBatch(ISequentialMiniBatch miniBatch, float[] memory, IRecurrentTrainingContext context, Action<IMatrix> beforeBackProp, Action<IMatrix> afterBackProp)
        {
            var trainingContext = context.TrainingContext;

            _lap.PushLayer();
            try
            {
                // one entry per sequence step: (backpropagation actions, expected output, actual output, mini batch, step index)
                var updateStack = new Stack<Tuple<Stack<INeuralNetworkRecurrentBackpropagation>, IMatrix, IMatrix, ISequentialMiniBatch, int>>();

                context.ExecuteForward(miniBatch, memory, (k, fc) =>
                {
                    var layerStack = new Stack<INeuralNetworkRecurrentBackpropagation>();
                    foreach (var action in _layer)
                    {
                        layerStack.Push(action.Execute(fc, true));
                    }
                    updateStack.Push(Tuple.Create(layerStack, miniBatch.GetExpectedOutput(fc, k), fc[0], miniBatch, k));
                });

                // backpropagate, accumulating errors across the sequence
                using (var updateAccumulator = new UpdateAccumulator(trainingContext))
                {
                    IMatrix curr = null;
                    while (updateStack.Count > 0)
                    {
                        var update = updateStack.Pop();
                        var actionStack = update.Item1;

                        // calculate a fresh error only for steps that have an expected output;
                        // otherwise the error carried over from the later step is reused
                        var expectedOutput = update.Item2;
                        if (expectedOutput != null)
                        {
                            curr = trainingContext.ErrorMetric.CalculateDelta(update.Item3, expectedOutput);
                        }

                        // backpropagate
                        beforeBackProp?.Invoke(curr);
                        while (actionStack.Count > 0)
                        {
                            var backpropagationAction = actionStack.Pop();

                            // The third argument is always true (presumably "calculate output" - confirm against
                            // INeuralNetworkRecurrentBackpropagation): the returned signal is needed by the
                            // previous sequence step and by the initial memory adjustment below.
                            curr = backpropagationAction.Execute(curr, trainingContext, true, updateAccumulator);
                        }
                        afterBackProp?.Invoke(curr);

                        // apply any filters
                        foreach (var filter in _filter)
                        {
                            filter.AfterBackPropagation(update.Item4, update.Item5, curr);
                        }
                    }

                    // adjust the initial memory against the error signal
                    if (curr != null)
                    {
                        using (var columnSums = curr.ColumnSums())
                        {
                            var initialDelta = columnSums.AsIndexable();
                            for (var j = 0; j < memory.Length; j++)
                            {
                                memory[j] += initialDelta[j] * trainingContext.TrainingRate;
                            }
                        }
                    }
                }

                trainingContext.EndBatch();
            }
            finally
            {
                // always balance the PushLayer above, even if the forward or backward pass throws
                _lap.PopLayer();
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Trains the bidirectional network for up to <paramref name="numEpochs"/> epochs. For each mini batch
        /// the forward and backward passes are executed, the error is backpropagated through the sequence in
        /// reverse order, and both initial memories are adjusted from the error at the first sequence step.
        /// </summary>
        /// <param name="trainingData">The sequential training data.</param>
        /// <param name="forwardMemory">The initial forward memory. Updated in place.</param>
        /// <param name="backwardMemory">The initial backward memory. Updated in place.</param>
        /// <param name="numEpochs">The maximum number of epochs; training stops earlier when the training context reports it should not continue.</param>
        /// <param name="context">The recurrent training context.</param>
        /// <returns>A <see cref="BidirectionalMemory"/> created from the (updated) forward and backward memory arrays.</returns>
        public BidirectionalMemory Train(ISequentialTrainingDataProvider trainingData, float[] forwardMemory, float[] backwardMemory, int numEpochs, IRecurrentTrainingContext context)
        {
            var trainingContext = context.TrainingContext;
            var logger = trainingContext.Logger;

            // writes the given matrices inside a named element; does nothing when no logger is attached
            Action<string, List<IMatrix>> logMatrices = (elementName, matrices) =>
            {
                if (logger == null)
                {
                    return;
                }
                logger.WriteStartElement(elementName);
                foreach (var matrix in matrices)
                {
                    matrix.WriteTo(logger);
                }
                logger.WriteEndElement();
            };

            for (var i = 0; i < numEpochs && context.TrainingContext.ShouldContinue; i++)
            {
                trainingContext.StartEpoch(trainingData.Count);
                var batchErrorList = new List<double>();

                foreach (var miniBatch in _GetMiniBatches(trainingData, _stochastic, trainingContext.MiniBatchSize))
                {
                    _lap.PushLayer();
                    try
                    {
                        // one entry per sequence step: ((forward, backward) backpropagation actions, expected output, actual output, mini batch, step index)
                        var updateStack = new Stack<Tuple<Stack<Tuple<INeuralNetworkRecurrentBackpropagation, INeuralNetworkRecurrentBackpropagation>>, IMatrix, IMatrix, ISequentialMiniBatch, int>>();
                        context.ExecuteBidirectional(miniBatch, _layer, forwardMemory, backwardMemory, _padding, updateStack, null);

                        // backpropagate, accumulating errors across the sequence
                        using (var updateAccumulator = new UpdateAccumulator(trainingContext))
                        {
                            while (updateStack.Count > 0)
                            {
                                var update = updateStack.Pop();
                                var isT0 = updateStack.Count == 0;
                                var actionStack = update.Item1;

                                // calculate error; the list holds one matrix until it is split into forward/backward halves
                                var curr = new List<IMatrix>
                                {
                                    trainingContext.ErrorMetric.CalculateDelta(update.Item3, update.Item2)
                                };

                                // get a measure of the training error (half the mean of the squared deltas)
                                if (_collectTrainingError)
                                {
                                    batchErrorList.Add(curr[0].AsIndexable().Values.Select(v => Math.Pow(v, 2)).Average() / 2);
                                }

                                logMatrices("initial-error", curr);

                                // backpropagate
                                while (actionStack.Count > 0)
                                {
                                    var backpropagationAction = actionStack.Pop();
                                    var calculateOutput = actionStack.Count > 0 || isT0;

                                    // the first time both a forward and a backward action are present, split the combined error between them
                                    if (backpropagationAction.Item1 != null && backpropagationAction.Item2 != null && curr.Count == 1)
                                    {
                                        using (var m = curr[0])
                                        {
                                            var split = m.SplitRows(forwardMemory.Length);
                                            curr[0] = split.Left;
                                            curr.Add(split.Right);
                                        }
                                        logMatrices("post-split", curr);
                                    }

                                    if (backpropagationAction.Item1 != null)
                                    {
                                        using (var m = curr[0])
                                        {
                                            curr[0] = backpropagationAction.Item1.Execute(m, trainingContext, calculateOutput, updateAccumulator);
                                        }
                                    }
                                    if (backpropagationAction.Item2 != null)
                                    {
                                        using (var m = curr[1])
                                        {
                                            curr[1] = backpropagationAction.Item2.Execute(m, trainingContext, calculateOutput, updateAccumulator);
                                        }
                                    }

                                    logMatrices("error", curr);
                                }

                                // apply any filters
                                foreach (var filter in _filter)
                                {
                                    foreach (var item in curr)
                                    {
                                        filter.AfterBackPropagation(update.Item4, update.Item5, item);
                                    }
                                }

                                // adjust the initial memory against the error signal
                                // NOTE(review): assumes the error was split above so that curr has two entries - confirm
                                if (isT0)
                                {
                                    using (var columnSums0 = curr[0].ColumnSums())
                                    using (var columnSums1 = curr[1].ColumnSums())
                                    {
                                        var initialDelta = columnSums0.AsIndexable();
                                        for (var j = 0; j < forwardMemory.Length; j++)
                                        {
                                            forwardMemory[j] += initialDelta[j] * trainingContext.TrainingRate;
                                        }

                                        initialDelta = columnSums1.AsIndexable();
                                        for (var j = 0; j < backwardMemory.Length; j++)
                                        {
                                            backwardMemory[j] += initialDelta[j] * trainingContext.TrainingRate;
                                        }
                                    }
                                }
                            }
                        }

                        trainingContext.EndBatch();
                    }
                    finally
                    {
                        // cleanup must run even if the batch fails, otherwise the layer stack is left unbalanced
                        _lap.PopLayer();
                        miniBatch.Dispose();
                    }
                }

                // Average() throws on an empty sequence, so report zero when nothing was collected
                var epochError = _collectTrainingError && batchErrorList.Count > 0 ? batchErrorList.Average() : 0;
                trainingContext.EndRecurrentEpoch(epochError, context);
            }
            return new BidirectionalMemory(forwardMemory, backwardMemory);
        }