コード例 #1
0
        /// <summary>
        /// Builds a training set from the paired images in <c>InImgs</c> / <c>OutImgs</c> and
        /// trains <c>net</c> on it.
        /// </summary>
        /// <param name="cpi">How many times each image pair's sample is added to the training set.</param>
        /// <param name="cs">Forwarded as the second argument of <c>net.Learn</c> (presumably the epoch count; confirm against the network library).</param>
        /// <param name="seed">Kept so existing callers still compile; this method does not consume any random numbers.</param>
        public void Train(int cpi, int cs = 5, int seed = -1)
        {
            int outputLength = W * H * 3;
            TrainingSet ts = new TrainingSet(Inputs, outputLength);

            int pairIndex = 0;
            foreach (var inputImage in InImgs)
            {
                // The expected output for this input is the image at the same position in OutImgs.
                Image outputImage = OutImgs[pairIndex];
                pairIndex++;

                // Three values are stored per pixel, read in the order they appear in Dat.
                double[] inputVector = new double[Inputs];
                int inputValueCount = inputImage.W * inputImage.H * 3;
                for (int n = 0; n < inputValueCount; n++)
                {
                    inputVector[n] = GV(inputImage.Dat[n]);
                }

                // The target vector must come from the paired output image; filling it from the
                // input image would train the network to reproduce its input.
                double[] outputVector = new double[outputLength];
                int outputValueCount = outputImage.W * outputImage.H * 3;
                for (int n = 0; n < outputValueCount; n++)
                {
                    outputVector[n] = GV(outputImage.Dat[n]);
                }

                TrainingSample sample = new TrainingSample(inputVector, outputVector);
                for (int copy = 0; copy < cpi; copy++)
                {
                    ts.Add(sample);
                }
            }

            Ready = false;

            // Unsubscribe first so that calling Train repeatedly does not stack up duplicate
            // handlers; removing a handler that is not attached is a no-op.
            net.EndEpochEvent -= EndE;
            net.EndEpochEvent += EndE;

            net.Learn(ts, cs);
            net.StopLearning();
            Console.WriteLine("Done training mind.");
        }
コード例 #2
0
        /// <summary>
        /// Handles a click on the Calculate button: builds a backpropagation network from the
        /// settings entered in the form, trains it on the rows of <c>data</c>, then writes each
        /// row's prediction and its errors into <c>dgvData</c>.
        /// </summary>
        /// <param name="sender">Event source (not used).</param>
        /// <param name="e">Event arguments (not used).</param>
        private void tsmiCalculate_Click(object sender, EventArgs e)
        {
            // Number of rows of data that are used for both training and prediction.
            const int sampleCount = 17;

            // Create the input, hidden and output layers.
            ActivationLayer inputLayer  = GetLayer(cboInputLayerType.SelectedItem.ToString(), 2);
            ActivationLayer hiddenLayer = GetLayer(cboHiddenLayerType.SelectedItem.ToString(), int.Parse(txtHiddenLayerCount.Text));
            ActivationLayer outputLayer = GetLayer(cboOutputLayerType.SelectedItem.ToString(), 1);

            // Connect the layers; each connector gets a RandomFunction(0, 0.3) initializer.
            new BackpropagationConnector(inputLayer, hiddenLayer, ConnectionMode.Complete).Initializer  = new RandomFunction(0, 0.3);
            new BackpropagationConnector(hiddenLayer, outputLayer, ConnectionMode.Complete).Initializer = new RandomFunction(0, 0.3);

            // Create the neural network.
            var network = new BackpropagationNetwork(inputLayer, outputLayer);

            // The learning rate comes from the form. It must not be overwritten with constants
            // before Learn is called, otherwise the values entered by the user are silently ignored.
            network.SetLearningRate(double.Parse(txtInitialLearningRate.Text), double.Parse(txtFinalLearningRate.Text));

            // Train: columns 0 and 1 of each row are the inputs, column 2 is the expected output.
            var trainingSet = new TrainingSet(2, 1);

            for (var i = 0; i < sampleCount; i++)
            {
                var inputVector  = new double[] { data[i, 0], data[i, 1] };
                var outputVector = new double[] { data[i, 2] };
                trainingSet.Add(new TrainingSample(inputVector, outputVector));
            }

            network.Learn(trainingSet, int.Parse(txtTrainingEpochs.Text));
            network.StopLearning();

            // Predict each row and show the result next to the original data.
            for (var i = 0; i < sampleCount; i++)
            {
                var expected   = data[i, 2];
                var testInput  = new double[] { data[i, 0], data[i, 1] };
                var testOutput = network.Run(testInput)[0];

                var absolute = testOutput - expected;
                // Note: the relative error is taken against the predicted value, not the expected one.
                var relative = Math.Abs(absolute / testOutput);

                dgvData.Rows[i].Cells[3].Value = testOutput.ToString("f3");
                dgvData.Rows[i].Cells[4].Value = absolute.ToString("f3");
                dgvData.Rows[i].Cells[5].Value = (relative * 100).ToString("f1") + "%";
            }
        }
コード例 #3
0
ファイル: MainForm.cs プロジェクト: lanicon/waveletstudio
 /// <summary>
 /// Stops the running network (if there is one), draws the function it currently
 /// approximates on the graph, then clears the network reference and re-enables the controls.
 /// </summary>
 /// <param name="sender">Event source (not used).</param>
 /// <param name="e">Event arguments (not used).</param>
 private void StopLearning(object sender, EventArgs e)
 {
     if (network != null)
     {
         network.StopLearning();

         LineItem approximation = new LineItem("Approximated Function");
         approximation.Symbol.Type = SymbolType.None;
         approximation.Color       = Color.DarkOrchid;

         // Sample the network output while x stays below 10, stepping by 0.05.
         double x = 0;
         while (x < 10)
         {
             double[] input = new double[] { x };
             approximation.AddPoint(x, network.Run(input)[0]);
             x += 0.05d;
         }

         // The curve is added only for this refresh and removed again right afterwards.
         var curves = functionGraph.GraphPane.CurveList;
         curves.Add(approximation);
         functionGraph.Refresh();
         curves.Remove(approximation);
     }

     network = null;
     EnableControls(true);
 }