示例#1
0
        /// <summary>
        /// Runs one gradient-descent training step over the first <paramref name="testCaseCount"/> test cases.
        /// Schedules, as a single chain of dependent jobs:
        /// the feed-forward pass (<c>Evaluate</c>), the output-layer error, backpropagation of that error
        /// through the layers, accumulation of the weight gradients, and the weight/bias updates.
        /// The whole chain is completed before this method returns.
        /// </summary>
        /// <param name="inputData">Input rows for the test cases; forwarded to <c>Evaluate</c>.</param>
        /// <param name="expectedOutput">Expected network outputs for the same test cases.</param>
        /// <param name="testCaseCount">Number of test cases (rows) to train on.</param>
        /// <param name="errorSum">Batch error as computed by <c>ComputeErrorSum</c>.</param>
        /// <returns>Handle of the last scheduled job; it has already been completed.</returns>
        public JobHandle GradientDescentBackpropigate(NativeSlice2D <float> inputData, NativeSlice2D <float> expectedOutput, int testCaseCount, out float errorSum)
        {
            Profiler.BeginSample("NetworkEvaluator::GradientDescentBackpropigate");

            // Temporary buffer receiving the feed-forward output of every test case.
            var resultArray = new NativeArray2D <float>(testCaseCount, OutputLayer.Size);

            // Every job below depends on the previously scheduled one, so this single handle
            // always refers to the most recent work in the chain.
            JobHandle handle = default(JobHandle);

            try
            {
                // Feed-forward pass: fills resultArray and the layers' activations / weighted inputs.
                handle = Evaluate(inputData, resultArray.Slice(0, testCaseCount), testCaseCount);

                // Output-layer error, written into OutputLayer's error buffer.
                var computeOutputErrorJob = new ErrorEvaluators.CrossEntropySigmoidOutputErrorEvaluator()
                {
                    Expected           = expectedOutput,
                    Actuall            = resultArray.Slice(0, testCaseCount),
                    WeightedActivation = OutputLayer.SliceWeightedInputs(0, testCaseCount),
                    ErrorOut           = OutputLayer.SliceError(0, testCaseCount)
                };

                handle = computeOutputErrorJob.Schedule(OutputLayer.Error.Dimensions.y, 4, handle);

                // Walk backwards from the second-to-last layer. Each iteration propagates the
                // error into targetLayer and accumulates the gradients of nextLayer's weights.
                for (int layerIndex = _layers.Count - 2; layerIndex >= 0; layerIndex--)
                {
                    var targetLayer = _layers[layerIndex];
                    var nextLayer   = _layers[layerIndex + 1];

                    // No error is propagated into layer 0. Its activations are still needed
                    // below for the gradients of layer 1's weights.
                    if (layerIndex != 0)
                    {
                        handle = targetLayer.ModelLayer.BackpropigateLayer(targetLayer, nextLayer, testCaseCount, handle);
                    }

                    var accumulateGradientOverWeightJob = new ErrorEvaluators.AccumulateGradientOverWeight()
                    {
                        PreviousActivation = targetLayer.SliceActivations(0, testCaseCount),
                        NextError          = nextLayer.SliceError(0, testCaseCount),

                        WeightGradients = nextLayer.WeightGradients
                    };

                    handle = accumulateGradientOverWeightJob.Schedule(nextLayer.WeightGradients.Length, 4, handle);
                }

                // Start the worker threads on everything scheduled so far while the update
                // jobs are being set up.
                JobHandle.ScheduleBatchedJobs();

                // Apply the accumulated gradients to the weights and biases of every layer but the first.
                for (int layerIndex = 1; layerIndex < _layers.Count; layerIndex++)
                {
                    var layer = _layers[layerIndex];

                    // NOTE(review): TestCount is hard-coded to 1 although the gradients were accumulated
                    // over testCaseCount cases. Confirm whether the apply jobs are meant to average.
                    var applyGradientsToWeightsJob = new ErrorEvaluators.ApplyGradientToLayerWeights()
                    {
                        TestCount       = 1,
                        LearningRate    = LearningRate,
                        WeightGradients = layer.WeightGradients,
                        LayerWeights    = layer.ModelLayer.Weights
                    };

                    handle = applyGradientsToWeightsJob.Schedule(layer.ModelLayer.Weights.Length, 4, handle);

                    var applyGradientsToBiasesJob = new ErrorEvaluators.ApplyGradientToLayerBiases()
                    {
                        TestCount    = 1,
                        LearningRate = LearningRate,
                        LayerBiases  = layer.ModelLayer.Biases,
                        LayerErrors  = layer.Error.Slice(0, testCaseCount)
                    };

                    handle = applyGradientsToBiasesJob.Schedule(layer.ModelLayer.Biases.Length, 4, handle);
                }

                handle.Complete();

                // Compute the error sum for the test cases (resultArray is still alive here).
                ComputeErrorSum(expectedOutput, testCaseCount, out errorSum);

                return handle;
            }
            finally
            {
                // On the normal path the chain is already complete. If an exception interrupted
                // scheduling, wait for whatever was scheduled so no job still reads resultArray
                // when it is released.
                handle.Complete();
                resultArray.Dispose();
                Profiler.EndSample();
            }
        }
示例#2
0
        /// <summary>
        /// Initialises the data set from an input array and an expected-result array.
        /// The cases are split into a training portion and a testing portion.
        /// </summary>
        /// <param name="inputSize">Values per input row; must equal <c>inputData.Dimensions.y</c>.</param>
        /// <param name="outputSize">Values per result row; must equal <c>results.Dimensions.y</c>.</param>
        /// <param name="elements">Number of cases; must equal the x dimension of both arrays.</param>
        /// <param name="inputData">Input values, <paramref name="elements"/> x <paramref name="inputSize"/>.</param>
        /// <param name="results">Expected results, <paramref name="elements"/> x <paramref name="outputSize"/>.</param>
        /// <param name="trainingSetSize">
        /// Fraction of the cases (0..1) used for training. The count is truncated toward zero;
        /// the remaining cases form the testing set.
        /// </param>
        /// <exception cref="ArgumentException">An array's dimensions do not match the given sizes.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="trainingSetSize"/> is outside [0, 1] or NaN.</exception>
        protected void InitFromData(int inputSize, int outputSize, int elements, NativeArray2D <float> inputData, NativeArray2D <float> results, float trainingSetSize = 0.7f)
        {
            if (inputData.Dimensions.x != elements)
            {
                throw new ArgumentException(
                    $"Expected {elements} input rows but inputData has {inputData.Dimensions.x}.", nameof(inputData));
            }

            if (inputData.Dimensions.y != inputSize)
            {
                throw new ArgumentException(
                    $"Expected {inputSize} values per input row but inputData has {inputData.Dimensions.y}.", nameof(inputData));
            }

            if (results.Dimensions.x != elements)
            {
                throw new ArgumentException(
                    $"Expected {elements} result rows but results has {results.Dimensions.x}.", nameof(results));
            }

            if (results.Dimensions.y != outputSize)
            {
                throw new ArgumentException(
                    $"Expected {outputSize} values per result row but results has {results.Dimensions.y}.", nameof(results));
            }

            // Outside [0, 1] the split below would give a negative testing set or a training set
            // larger than the data. The negated form also rejects NaN.
            if (!(trainingSetSize >= 0f && trainingSetSize <= 1f))
            {
                throw new ArgumentOutOfRangeException(
                    nameof(trainingSetSize), trainingSetSize, "Training set fraction must be between 0 and 1.");
            }

            TrainingSetSize = (int)(elements * trainingSetSize);
            TestingSetSize  = elements - TrainingSetSize;

            CaseCount = elements;

            InputSize  = inputSize;
            ResultSize = outputSize;

            InputData      = inputData;
            ExpectedResult = results;
        }