Exemplo n.º 1
0
        /// <summary>
        /// Runs one gradient-descent training step for a batch of test cases. The step runs in this order:
        /// evaluates the network on <paramref name="inputData"/>, computes the output-layer error
        /// (cross-entropy / sigmoid evaluator), then backpropagates that error through the layers
        /// while accumulating weight gradients. It then applies the gradients to the weights and
        /// biases of every layer except the first.
        /// </summary>
        /// <param name="inputData">Input data for the batch, passed to <c>Evaluate</c>.</param>
        /// <param name="expectedOutput">Expected outputs, used for the output error and the error sum.</param>
        /// <param name="testCaseCount">Number of test cases; every per-layer buffer is sliced to this count.</param>
        /// <param name="errorSum">Receives the value computed by <c>ComputeErrorSum</c> once the update has completed.</param>
        /// <returns>
        /// The handle of the last scheduled update job. It has already been completed before this
        /// method returns, so callers do not need to wait on it.
        /// </returns>
        public JobHandle GradientDescentBackpropigate(NativeSlice2D <float> inputData, NativeSlice2D <float> expectedOutput, int testCaseCount, out float errorSum)
        {
            Profiler.BeginSample("NetworkEvaluator::GradientDescentBackpropigate");
            var resultArray = new NativeArray2D <float>(testCaseCount, OutputLayer.Size);

            // Most recently scheduled job. Every job below is scheduled with the previous handle as
            // its dependency, so completing this one also completes everything scheduled before it.
            JobHandle handle = default(JobHandle);
            try
            {
                handle = Evaluate(inputData, resultArray.Slice(0, testCaseCount), testCaseCount);

                // Output-layer error. ErrorEvaluators.QuadraticSigmoidOutputErrorEvaluator takes the
                // same fields and can be swapped in if a quadratic cost is wanted instead.
                var computeOutputErrorJob = new ErrorEvaluators.CrossEntropySigmoidOutputErrorEvaluator()
                {
                    Expected           = expectedOutput,
                    Actuall            = resultArray.Slice(0, testCaseCount),
                    WeightedActivation = OutputLayer.SliceWeightedInputs(0, testCaseCount),
                    ErrorOut           = OutputLayer.SliceError(0, testCaseCount)
                };
                handle = computeOutputErrorJob.Schedule(OutputLayer.Error.Dimensions.y, 4, handle);

                handle = ScheduleBackpropagationJobs(testCaseCount, handle);

                // Start the worker threads on the backpropagation jobs while the update jobs are scheduled.
                JobHandle.ScheduleBatchedJobs();

                handle = ScheduleGradientApplicationJobs(testCaseCount, handle);
                handle.Complete();

                ComputeErrorSum(expectedOutput, testCaseCount, out errorSum);
                return handle;
            }
            finally
            {
                // Runs on both success and failure. No scheduled job may still read resultArray when
                // its native memory is freed, and the profiler sample must always be closed.
                // Completing an already-completed or default handle is harmless.
                handle.Complete();
                resultArray.Dispose();
                Profiler.EndSample();
            }
        }

        /// <summary>
        /// Schedules backpropagation from the last hidden layer down to the first layer. For each
        /// pair of adjacent layers, the target layer is backpropagated from the next layer (skipped
        /// for layer 0). The next layer's weight gradients are then accumulated from the target
        /// layer's activations and the next layer's error.
        /// </summary>
        /// <param name="testCaseCount">Number of test cases each layer buffer is sliced to.</param>
        /// <param name="dependency">Handle the first scheduled job depends on.</param>
        /// <returns>Handle of the last scheduled job (chained on all earlier ones).</returns>
        private JobHandle ScheduleBackpropagationJobs(int testCaseCount, JobHandle dependency)
        {
            JobHandle handle = dependency;

            for (int layerIndex = _layers.Count - 2; layerIndex >= 0; layerIndex--)
            {
                var targetLayer = _layers[layerIndex];
                var nextLayer   = _layers[layerIndex + 1];

                // Layer 0 is not backpropagated. It is presumably the input layer, which has no
                // error of its own; its activations are still used for the gradients below.
                if (layerIndex != 0)
                {
                    handle = targetLayer.ModelLayer.BackpropigateLayer(targetLayer, nextLayer, testCaseCount, handle);
                }

                var accumulateGradientOverWeightJob = new ErrorEvaluators.AccumulateGradientOverWeight()
                {
                    PreviousActivation = targetLayer.SliceActivations(0, testCaseCount),
                    NextError          = nextLayer.SliceError(0, testCaseCount),
                    WeightGradients    = nextLayer.WeightGradients
                };

                handle = accumulateGradientOverWeightJob.Schedule(nextLayer.WeightGradients.Length, 4, handle);
            }

            return handle;
        }

        /// <summary>
        /// Schedules the jobs that apply the accumulated weight gradients to each layer's weights and
        /// the layer errors to each layer's biases, for every layer except the first.
        /// </summary>
        /// <param name="testCaseCount">Number of test cases the layer error is sliced to.</param>
        /// <param name="dependency">Handle the first scheduled job depends on.</param>
        /// <returns>Handle of the last scheduled job (chained on all earlier ones).</returns>
        private JobHandle ScheduleGradientApplicationJobs(int testCaseCount, JobHandle dependency)
        {
            JobHandle handle = dependency;

            for (int layerIndex = 1; layerIndex < _layers.Count; layerIndex++)
            {
                var layer = _layers[layerIndex];

                // NOTE(review): TestCount is fixed at 1 although the gradients were accumulated over
                // testCaseCount cases. Confirm whether the update is meant to be averaged over the batch.
                var applyGradientsToWeightsJob = new ErrorEvaluators.ApplyGradientToLayerWeights()
                {
                    TestCount       = 1,
                    LearningRate    = LearningRate,
                    WeightGradients = layer.WeightGradients,
                    LayerWeights    = layer.ModelLayer.Weights
                };
                handle = applyGradientsToWeightsJob.Schedule(layer.ModelLayer.Weights.Length, 4, handle);

                var applyGradientsToBiasesJob = new ErrorEvaluators.ApplyGradientToLayerBiases()
                {
                    TestCount    = 1,
                    LearningRate = LearningRate,
                    LayerBiases  = layer.ModelLayer.Biases,
                    // Go through the virtual accessor, as the rest of the training pass does, so
                    // layers that override SliceError are respected.
                    LayerErrors  = layer.SliceError(0, testCaseCount)
                };
                handle = applyGradientsToBiasesJob.Schedule(layer.ModelLayer.Biases.Length, 4, handle);
            }

            return handle;
        }
Exemplo n.º 2
0
 /// <summary>
 /// Returns <c>WeightedInput.Slice(start, count)</c>. The method is virtual so derived layers can
 /// return a different view.
 /// </summary>
 /// <param name="start">First index of the slice.</param>
 /// <param name="count">Number of entries in the slice.</param>
 public virtual NativeSlice2D <float> SliceWeightedInputs(int start, int count) =>
     WeightedInput.Slice(start, count);
Exemplo n.º 3
0
 /// <summary>
 /// Returns <c>Error.Slice(start, count)</c>. The method is virtual so derived layers can return a
 /// different view of their error buffer.
 /// </summary>
 /// <param name="start">First index of the slice.</param>
 /// <param name="count">Number of entries in the slice.</param>
 public virtual NativeSlice2D <float> SliceError(int start, int count) =>
     Error.Slice(start, count);
Exemplo n.º 4
0
        // Gradient buffer for this layer's weights. Unlike the per-test-case buffers, it does not
        // need to scale with the slice size.
        public NativeArray <float> WeightGradients;

        /// <summary>
        /// Returns <c>OutputActivation.Slice(start, count)</c>. The method is virtual so derived
        /// layers can return a different view of their activations.
        /// </summary>
        /// <param name="start">First index of the slice.</param>
        /// <param name="count">Number of entries in the slice.</param>
        public virtual NativeSlice2D <float> SliceActivations(int start, int count) =>
            OutputActivation.Slice(start, count);