コード例 #1
0
        /// <summary>
        /// Trains <c>Network</c> in batch mode. Each epoch runs a forward and backward pass over every
        /// sample of <paramref name="samples"/> (in order) and then performs a single weight update.
        /// Training stops after <c>param.MaxSteps</c> epochs or once the error on
        /// <paramref name="testSamples"/> drops to <c>param.ErrorStopValue</c> or below.
        /// At least one epoch always runs, because the stop condition is checked only after an epoch.
        /// </summary>
        /// <param name="samples">Training set; every sample is visited once per epoch.</param>
        /// <param name="testSamples">Set passed to <c>Network.Error</c> after each epoch to compute the stop error.</param>
        /// <param name="param">
        /// Training parameters: <c>Eta</c> and <c>Alpha</c> are forwarded to <c>infos.ProcessDeltas</c>
        /// (<c>Alpha</c> also to <c>infos.UpdateWeights</c>). <c>MaxSteps</c> and <c>ErrorStopValue</c> are
        /// the stop criteria. The optional <c>CallBack</c> receives <c>(step, error, false)</c> after every
        /// epoch and <c>(step, error, true)</c> once when training finishes.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="samples"/> or <paramref name="param"/> is null.</exception>
        public void TrainNetworkByBatch(DataSamples samples, DataSamples testSamples, BackPropParams param)
        {
            // Validate before InitNeurons/InitNG. Those calls receive a Random and presumably
            // re-randomise the network, so a bad call must fail before touching that state.
            if (samples == null)
            {
                throw new ArgumentNullException("samples");
            }
            if (param == null)
            {
                throw new ArgumentNullException("param");
            }

            var rnd = new Random();

            // Initialise neuron / training state from the random source before the first epoch.
            infos.InitNeurons(rnd);
            infos.InitNG(Network, rnd);

            var count = samples.Count;
            var step = 0;
            double error;
            bool stop;

            do
            {
                for (int sampleIndex = 0; sampleIndex < count; sampleIndex++)
                {
                    var data = samples[sampleIndex];

                    // Forward pass.
                    Network.Inputs = data.Inputs;
                    Network.Calc();

                    // Backward pass. The flag is true only for the first sample of the epoch
                    // (presumably resets the per-batch accumulation — confirm in ProcessDeltas).
                    infos.ProcessGradients(data.Outputs);
                    infos.ProcessDeltas(param.Eta, param.Alpha, sampleIndex == 0);
                }

                // Exactly one weight update per epoch, after the whole batch has been processed.
                infos.UpdateWeights(param.Alpha);

                error = Network.Error(testSamples);
                step++;
                if (param.CallBack != null)
                {
                    param.CallBack(step, error, false);
                }

                // Kept in this form (not negated into '<' / '>') so that a NaN error behaves
                // exactly as before: it never satisfies the error criterion.
                stop = (step >= param.MaxSteps) || (error <= param.ErrorStopValue);
            }
            while (!stop);

            if (param.CallBack != null)
            {
                param.CallBack(step, error, true);
            }
        }
コード例 #2
0
        /// <summary>
        /// Trains <c>Network</c> sample by sample. Each epoch shuffles the visiting order of
        /// <paramref name="samples"/> and calls <c>TrainBySample</c> once per sample.
        /// Training stops after <c>param.MaxSteps</c> epochs or once the error on
        /// <paramref name="testSamples"/> drops to <c>param.ErrorStopValue</c> or below.
        /// At least one epoch always runs, because the stop condition is checked only after an epoch.
        /// </summary>
        /// <param name="samples">Training set; each sample is visited once per epoch in shuffled order.</param>
        /// <param name="testSamples">Set passed to <c>Network.Error</c> after each epoch to compute the stop error.</param>
        /// <param name="param">
        /// Training parameters: <c>Eta</c> and <c>Alpha</c> are forwarded to <c>TrainBySample</c>.
        /// <c>MaxSteps</c> and <c>ErrorStopValue</c> are the stop criteria. The optional <c>CallBack</c>
        /// receives <c>(step, error, false)</c> after every epoch and <c>(step, error, true)</c> once when
        /// training finishes.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="samples"/> or <paramref name="param"/> is null.</exception>
        public void TrainNetworkBySample(DataSamples samples, DataSamples testSamples, BackPropParams param)
        {
            // Validate before InitNeurons/InitNG. Those calls receive a Random and presumably
            // re-randomise the network, so a bad call must fail before touching that state.
            if (samples == null)
            {
                throw new ArgumentNullException("samples");
            }
            if (param == null)
            {
                throw new ArgumentNullException("param");
            }

            var rnd = new Random();
            var items = GenerateIndexes(samples);

            // Initialise neuron / training state from the random source before the first epoch.
            infos.InitNeurons(rnd);
            infos.InitNG(Network, rnd);

            var step = 0;
            double error;
            bool stop;

            do
            {
                // New visiting order every epoch. The return value is not used, so presumably
                // ShuffleIndexes reorders 'items' in place — confirm.
                ShuffleIndexes(items, rnd);
                for (int i = 0; i < items.Count; i++)
                {
                    TrainBySample(samples[items[i]], param.Eta, param.Alpha);
                }

                error = Network.Error(testSamples);
                step++;
                if (param.CallBack != null)
                {
                    param.CallBack(step, error, false);
                }

                // Kept in this form (not negated into '<' / '>') so that a NaN error behaves
                // exactly as before: it never satisfies the error criterion.
                stop = (step >= param.MaxSteps) || (error <= param.ErrorStopValue);
            }
            while (!stop);

            if (param.CallBack != null)
            {
                param.CallBack(step, error, true);
            }
        }