Ejemplo n.º 1
0
        /// <summary>
        /// Computes the error of <paramref name="net"/> on a single learning pattern.
        /// </summary>
        /// <remarks>
        /// The pattern input is propagated through the net, and each actual output is
        /// differenced with the expected output at the same position (<c>actual.Sub(expected)</c>).
        /// For every difference, the alpha-weighted squares of the X and Y values of each
        /// level are summed. The total over all outputs is halved.
        /// Outputs are matched by position. If the net yields fewer outputs than
        /// <c>pattern.Output</c> contains, the extra expected values are ignored.
        /// </remarks>
        /// <param name="net">Network whose actual output is evaluated.</param>
        /// <param name="pattern">Pattern supplying the input and the expected output.</param>
        /// <returns>Half of the summed alpha-weighted squared level values of all output errors.</returns>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// The net yields more outputs than <c>pattern.Output</c> contains.
        /// </exception>
        protected static double GetPatternError(INet net, ILearningPattern pattern)
        {
            var actual = net.Propagate(pattern.Input);
            // Materialize the expected outputs once: ElementAt(i) on a non-list sequence
            // re-enumerates it from the start on every call, which made this loop O(n^2).
            var expected = pattern.Output.ToList();

            var patternError = 0.0;
            var i = 0;
            foreach (var actualNumber in actual)
            {
                var errorNumber = actualNumber.Sub(expected[i]);
                i++;

                var leftError = 0.0;
                var rightError = 0.0;
                errorNumber.ForeachLevel((alpha, level) =>
                {
                    leftError += alpha * (level.X * level.X);
                    rightError += alpha * (level.Y * level.Y);
                });

                // Same summation order as before: per-output total first, then accumulate.
                patternError += leftError + rightError;
            }

            return patternError / 2.0;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Processes one learning pattern: propagates the error for the pattern's expected
        /// output back through the network layers, then calculates the weight deltas.
        /// </summary>
        /// <remarks>
        /// Call only after the net has propagated this pattern's input.
        /// This method does not itself apply the weight changes. An earlier per-pattern
        /// <c>ChangeAndSetWeights</c> call was disabled here; presumably the deltas are
        /// applied elsewhere (e.g. once per batch) — confirm against the training loop.
        /// </remarks>
        /// <param name="net">Network being trained.</param>
        /// <param name="learningPattern">Pattern whose expected output drives the error propagation.</param>
        /// <param name="currentPatternError">Not used by this implementation.</param>
        protected override void LearnPattern(INet net, ILearningPattern learningPattern, double currentPatternError)
        {
            PropagateErrorOnLayers(net.Layers, learningPattern.Output);
            CalculateWeightDelta(net);
        }
Ejemplo n.º 3
0
 /// <summary>
 /// Adds the weight gradient of a single learning pattern to the running sum in <c>_gradient</c>.
 /// </summary>
 /// <remarks>
 /// The error for <paramref name="learningPattern"/> is propagated through the layers
 /// (nabla F(x_k)), a weights gradient is built from the layers, and it is combined with
 /// the accumulated gradient via <c>Sum</c>. When <c>_gradient</c> is null, this
 /// pattern's gradient becomes the initial value.
 /// </remarks>
 /// <param name="net">Network being trained.</param>
 /// <param name="learningPattern">Pattern whose expected output drives the error propagation.</param>
 /// <param name="currentPatternError">Not used by this implementation.</param>
 protected override void LearnPattern(INet net, ILearningPattern learningPattern, double currentPatternError)
 {
     // nabla F(x_k): error signals for this pattern
     PropagateErrorOnLayers(net.Layers, learningPattern.Output);
     var patternGradient = CreateWeightsGradient(net.Layers);

     if (_gradient == null)
     {
         // first contribution: start the running sum with this pattern's gradient
         _gradient = patternGradient;
     }
     else
     {
         _gradient = _gradient.Sum(patternGradient);
     }
 }
Ejemplo n.º 4
0
 /// <summary>
 /// Learns from a single pattern. Each concrete learning algorithm supplies its own implementation.
 /// </summary>
 /// <param name="net">Network being trained.</param>
 /// <param name="learningPattern">Pattern to learn from (provides the input and the expected output).</param>
 /// <param name="currentPatternError">
 /// Error for this pattern — presumably the value returned by CalculatePatternError; confirm against the caller.
 /// </param>
 protected abstract void LearnPattern(INet net, ILearningPattern learningPattern, double currentPatternError);
Ejemplo n.º 5
0
 /// <summary>
 /// Computes the error of <paramref name="net"/> on <paramref name="pattern"/>.
 /// The default implementation delegates to <c>GetPatternError</c>; derived classes
 /// may override it to use a different error measure.
 /// </summary>
 /// <param name="net">Network whose output is evaluated.</param>
 /// <param name="pattern">Pattern supplying the input and the expected output.</param>
 /// <returns>The error value for this pattern.</returns>
 protected virtual double CalculatePatternError(INet net, ILearningPattern pattern)
 {
     double error = GetPatternError(net, pattern);
     return error;
 }