Exemplo n.º 1
0
        /// <summary>
        /// Train the machine with an error-driven weight update. Training stops when an
        /// epoch makes no weight change, when <c>NeuralNetwork.MAX_EPOCHS</c> epochs have
        /// run, or when the epoch's mean error drops below a fixed threshold.
        /// </summary>
        /// <param name="trainCount">Number of data set samples per class to use in training.</param>
        public override void train(int trainCount)
        {
            // Training stops once the epoch's mean (0.5 * error^2) falls below this value.
            const double minError = 1E-2;

            int  epochs        = 0;
            bool moreError     = true;
            bool weightChanged = true;

            // One row of per-sample errors per class. Rows are reallocated every epoch,
            // so samples whose output already matches the target contribute 0.
            double[][] error = new double[this.classMask.Length][];

            /*REAL WORK*/
            while (weightChanged && epochs < NeuralNetwork.MAX_EPOCHS && moreError)
            {
                weightChanged = false;
                for (int i = 0; i < this.classMask.Length; i++)            //Class index.
                {
                    error[i] = new double[trainCount];
                    // NOTE(review): classMask entries look like 1-based class ids (hence the -1) — confirm.
                    int classIndex = (int)classMask[i] - 1;
                    for (int j = 0; j < trainCount; j++)                //Data index.
                    {
                        // Feature-masked sample with the bias prepended. It is built once and
                        // reused for both the forward pass and the weight update.
                        double[] lineData = VectorTools.prepend(
                            VectorTools.trim(this.data[classIndex][j], this.featureMask), this.bias);

                        double net = getvalue(VectorTools.multiply(this.weight, lineData));

                        if (net != this.target[i])
                        {
                            weightChanged = true;
                            double delta = this.target[i] - net;
                            this.weight  = VectorTools.sum(this.weight, VectorTools.multiply(lineData, delta * this.eta));
                            error[i][j]  = 0.5 * delta * delta;
                        }
                    }
                }

                // Only the current epoch's mean is ever consulted, so a scalar suffices.
                // This avoids an array whose size could disagree with the loop bound.
                double mse = VectorTools.mean(VectorTools.get1D(error));   //Mean square error.
                if (mse < minError)
                {
                    moreError = false;
                }
                epochs++;
            }
        }
        /// <summary>Train the machine using the perceptron algorithm.</summary>
        /// <remarks>
        /// Epochs repeat until a full epoch passes without any weight change, or until
        /// <c>NeuralNetwork.MAX_EPOCHS</c> epochs have run. For each misclassified sample,
        /// the weights are updated by <c>eta * (target - signum(net)) * x</c>. Here
        /// <c>x</c> is the feature-masked sample with the bias prepended, and <c>net</c> is
        /// <c>VectorTools.multiply(weight, x)</c>.
        /// </remarks>
        /// <param name="trainCount">Number of data set samples per class to use in training.</param>
        public override void train(int trainCount)
        {
            int  epochs        = 0;
            bool weightChanged = true;

            /*REAL WORK*/
            while (weightChanged && epochs < NeuralNetwork.MAX_EPOCHS)
            {
                weightChanged = false;
                for (int i = 0; i < this.classMask.Length; i++)            //Class index.
                {
                    // NOTE(review): classMask entries look like 1-based class ids (hence the -1) — confirm.
                    int classIndex = (int)classMask[i] - 1;
                    for (int j = 0; j < trainCount; j++)                //Data index.
                    {
                        // Feature-masked sample with the bias prepended. It is built once and
                        // reused for both the forward pass and the weight update.
                        double[] lineData = VectorTools.prepend(
                            VectorTools.trim(this.data[classIndex][j], this.featureMask), this.bias);

                        int sgnOut = ActivationFunctions.signum(VectorTools.multiply(this.weight, lineData));

                        if (sgnOut != this.target[i])
                        {
                            weightChanged = true;
                            double error = this.target[i] - sgnOut;
                            this.weight  = VectorTools.sum(this.weight, VectorTools.multiply(lineData, error * this.eta));
                        }
                    }
                }

                epochs++;
            }
        }