Example #1
0
        /// <summary>
        /// Trains the network on the supplied data for a number of epochs, feeding each
        /// mini-batch forward through every layer and backpropagating the resulting error.
        /// </summary>
        /// <param name="trainingData">Provides the training mini-batches; if it also implements
        /// ICanBackpropagate it will receive the final backpropagated error as well</param>
        /// <param name="numEpochs">Maximum number of epochs to train for</param>
        /// <param name="context">Training context supplying the mini-batch size, error metric
        /// and the ShouldContinue flag that can end training early</param>
        public void Train(ITrainingDataProvider trainingData, int numEpochs, ITrainingContext context)
        {
            IMatrix curr = null;

            // If the data provider itself participates in backpropagation (e.g. a trainable
            // pre-processing stage), it sits at the bottom of the layer stack so it receives
            // the error signal last.
            var additionalBackpropagation = trainingData as ICanBackpropagate;

            for (int i = 0; i < numEpochs && context.ShouldContinue; i++)
            {
                context.StartEpoch(trainingData.Count);
                trainingData.StartEpoch();
                var batchErrorList = new List <double>();

                // iterate over each mini batch
                foreach (var miniBatch in _GetMiniBatches(trainingData, _stochastic, context.MiniBatchSize))
                {
                    // every intermediate matrix is collected here and disposed at the end
                    // of the batch to keep memory bounded
                    var garbage = new List <IMatrix>();
                    garbage.Add(curr = miniBatch.Input);
                    _lap.PushLayer();

                    // set up the layer stack (last layer pushed is the first to backpropagate)
                    var layerStack = new Stack <ICanBackpropagate>();
                    if (additionalBackpropagation != null)
                    {
                        layerStack.Push(additionalBackpropagation);
                    }

                    // feed forward
                    foreach (var layer in _layer)
                    {
                        garbage.Add(curr = layer.FeedForward(curr, true));
                        layerStack.Push(layer);
                    }

                    // calculate the error against the training examples
                    using (var expectedOutput = miniBatch.ExpectedOutput) {
                        garbage.Add(curr = context.ErrorMetric.CalculateDelta(curr, expectedOutput));

                        // calculate the training error for this mini batch: half the mean squared delta
                        if (_calculateTrainingError)
                        {
                            batchErrorList.Add(curr.AsIndexable().Values.Select(v => Math.Pow(v, 2)).Average() / 2);
                        }

                        // backpropagate the error
                        while (layerStack.Any())
                        {
                            var currentLayer = layerStack.Pop();
                            garbage.Add(curr = currentLayer.Backpropagate(curr, context, layerStack.Any()));
                        }
                    }

                    // clear memory
                    context.EndBatch();
                    garbage.ForEach(m => m?.Dispose());
                    _lap.PopLayer();
                }

                // Guard the Count check: Enumerable.Average throws InvalidOperationException
                // on an empty sequence, which happens when an epoch yields no mini-batches.
                context.EndEpoch(_calculateTrainingError && batchErrorList.Count > 0 ? batchErrorList.Average() : 0f);
            }
        }
Example #2
0
        /// <summary>
        /// Executes the graph against a single row of input and returns the single result.
        /// </summary>
        /// <param name="input">The input vector to feed through the graph</param>
        /// <returns>The execution result, or null if no operation produced one</returns>
        public ExecutionResult Execute(float[] input)
        {
            _lap.PushLayer();
            ExecutionResult ret = null;

            _dataSource = new SingleRowDataSource(input, false, MiniBatchSequenceType.Standard, 0);
            var provider = new MiniBatchProvider(_dataSource, false);

            try
            {
                using (var executionContext = new ExecutionContext(_lap)) {
                    executionContext.Add(provider.GetMiniBatches(1, mb => _Execute(executionContext, mb)));

                    IGraphOperation operation;
                    while ((operation = executionContext.GetNextOperation()) != null)
                    {
                        _lap.PushLayer();
                        try
                        {
                            operation.Execute(executionContext);
                            ret = _GetResults().Single();
                            _ClearContextList();
                        }
                        finally
                        {
                            // release the per-operation allocation layer even if
                            // Execute or _GetResults throws
                            _lap.PopLayer();
                        }
                    }
                }
            }
            finally
            {
                // Always pop the outer allocation layer and clear the shared data source;
                // previously an exception mid-execution left _dataSource pointing at this
                // call's single-row source and leaked the pushed layer.
                _lap.PopLayer();
                _dataSource = null;
            }
            return(ret);
        }