/// <summary>
///   Runs error-based pruning on <paramref name="tree"/> and reports how many
///   nodes the tree enumerates afterwards. The samples from index
///   <paramref name="training"/> to the end of the arrays form the pruning set.
/// </summary>
/// <param name="inputs">Sample feature vectors; rows from <paramref name="training"/> onward are used for pruning.</param>
/// <param name="outputs">Class labels parallel to <paramref name="inputs"/>.</param>
/// <param name="tree">Tree handed to <c>ErrorBasedPruning</c>; its nodes are counted after pruning.</param>
/// <param name="training">Index of the first sample that belongs to the pruning set.</param>
/// <param name="threshold">Value assigned to <c>ErrorBasedPruning.Threshold</c>.</param>
/// <param name="nodeCount2">Number of nodes enumerated from <paramref name="tree"/> once pruning has stopped.</param>
private static void repeat(double[][] inputs, int[] outputs,
                                   DecisionTree tree, int training, double threshold,
                                   out int nodeCount2)
        {
            // Each array is sliced up to its own last index.
            var pruningInputs       = inputs.Submatrix(training, inputs.Length - 1);
            var pruningOutputs      = outputs.Submatrix(training, outputs.Length - 1);
            ErrorBasedPruning prune = new ErrorBasedPruning(tree, pruningInputs, pruningOutputs);

            prune.Threshold = threshold;

            // Keep running pruning passes while each pass lowers the reported error.
            double lastError;
            double error = Double.PositiveInfinity;

            do
            {
                lastError = error;
                error     = prune.Run();
            } while (error < lastError);

            // Count the nodes remaining in the tree after pruning.
            nodeCount2 = 0;
            foreach (var node in tree)
            {
                nodeCount2++;
            }
        }
        /// <summary>
        /// Builds the nursery example tree and prunes it with <c>ErrorBasedPruning</c>.
        /// The samples after the first 6000 are the pruning set. Pruning repeats until
        /// the error stops decreasing, then the test checks the final error and the
        /// node counts before and after pruning.
        /// </summary>
        public void RunTest()
        {
            // Seed Accord's shared generator before the example is created
            // (presumably so the generated data and expected counts are repeatable).
            Accord.Math.Random.Generator.Seed = 0;

            const int trainingSampleCount = 6000;

            double[][] samples;
            int[] labels;
            DecisionTree tree = ReducedErrorPruningTest.createNurseryExample(out samples, out labels, trainingSampleCount);

            // Size of the tree before any pruning.
            int nodesBefore = 0;
            foreach (var visited in tree)
            {
                nodesBefore++;
            }

            // Every sample after the training portion is held out for pruning.
            int lastIndex = samples.Length - 1;
            var holdOutInputs = samples.Submatrix(trainingSampleCount, lastIndex);
            var holdOutOutputs = labels.Submatrix(trainingSampleCount, lastIndex);

            var pruning = new ErrorBasedPruning(tree, holdOutInputs, holdOutOutputs)
            {
                Threshold = 0.1
            };

            // Prune until a pass fails to lower the error.
            double previousError = Double.PositiveInfinity;
            double currentError = pruning.Run();
            while (currentError < previousError)
            {
                previousError = currentError;
                currentError = pruning.Run();
            }

            // Size of the tree after pruning.
            int nodesAfter = 0;
            foreach (var visited in tree)
            {
                nodesAfter++;
            }

            Assert.AreEqual(0.28922413793103446, currentError, 5e-4);
            Assert.AreEqual(447, nodesBefore);
            Assert.AreEqual(424, nodesAfter);
        }
// Ejemplo n.º 3 (Example no. 3)
// 0
        /// <summary>
        /// Builds the nursery example tree via <c>createNurseryExample</c> with a
        /// training count of 6000 and prunes it with <c>ErrorBasedPruning</c>. The
        /// samples from index 6000 onward are the pruning set. Pruning repeats until
        /// the error stops decreasing, then the test checks the final error and the
        /// node counts before and after pruning.
        /// </summary>
        public void RunTest()
        {
            // Absolute tolerance for the floating-point error check. An exact equality
            // comparison fails on harmless last-bit rounding differences.
            const double errorTolerance = 1e-10;

            double[][] inputs;
            int[]      outputs;

            int          training = 6000;
            DecisionTree tree     = ReducedErrorPruningTest.createNurseryExample(out inputs, out outputs, training);

            // Node count of the unpruned tree.
            int nodeCount = 0;

            foreach (var node in tree)
            {
                nodeCount++;
            }

            // Each array is sliced up to its own last index.
            var pruningInputs       = inputs.Submatrix(training, inputs.Length - 1);
            var pruningOutputs      = outputs.Submatrix(training, outputs.Length - 1);
            ErrorBasedPruning prune = new ErrorBasedPruning(tree, pruningInputs, pruningOutputs);

            prune.Threshold = 0.1;

            // Keep running pruning passes while each pass lowers the reported error.
            double lastError, error = Double.PositiveInfinity;

            do
            {
                lastError = error;
                error     = prune.Run();
            } while (error < lastError);

            // Node count after pruning.
            int nodeCount2 = 0;

            foreach (var node in tree)
            {
                nodeCount2++;
            }

            Assert.AreEqual(0.25459770114942532, error, errorTolerance);
            Assert.AreEqual(447, nodeCount);
            Assert.AreEqual(193, nodeCount2);
        }
// Ejemplo n.º 4 (Example no. 4)
// 0
        /// <summary>
        /// Builds the decision tree (构建决策树).
        /// </summary>
        /// <remarks>
        /// Steps:
        /// 1. Draws sample cells via <c>getSample()</c>.
        /// 2. For each cell, builds a feature vector from every buffer in
        ///    <c>driveBufferList</c> at the cell's flattened index
        ///    (<c>row * width + col</c>). It then appends the value returned by
        ///    <c>GetNeighbourAffect(beginBuffer, width, height, row, col, 3)</c>.
        /// 3. Labels the cell 1 when <c>(int)beginBuffer[index]</c> equals
        ///    <c>landInfo.UrbanInfos[0].LandUseTypeValue</c>, otherwise 0.
        /// 4. Trains a two-class tree with <c>C45Learning</c> on the first half of the
        ///    samples. It then repeatedly applies <c>ErrorBasedPruning</c> using the
        ///    second half until the pruning error stops decreasing.
        /// 5. Stores the compiled tree expression in <c>func</c>. The error rate and the
        ///    extracted rules are reported through <c>updateConsoleEvent</c>.
        /// </remarks>
        public void BuildTree()
        {
            // Sampling.
            List<Cell> samplePoints = getSample();

            updateConsoleEvent("-----------");
            updateConsoleEvent("起始城市栅格数目:" + this.BeginCityCnt);
            updateConsoleEvent("目标城市栅格数目:" + this.EndCityCnt);
            updateConsoleEvent("-----------");
            updateConsoleEvent("-----------开始训练----------");

            // Number of samples.
            int count = samplePoints.Count;

            // Land-use code that marks a cell as positive (output 1). It is evaluated once
            // here rather than on every loop iteration because it does not depend on the sample.
            var positiveLandUseValue = this.landInfo.UrbanInfos[0].LandUseTypeValue;

            // Build the input (feature) and output (label) data sets.
            double[][] inputs  = new double[count][];
            int[]      outputs = new int[count];
            for (int i = 0; i < count; i++)
            {
                Cell cell = samplePoints[i];

                // Flattened row-major index of the cell within the raster buffers.
                int pos = cell.row * width + cell.col;

                // One feature per driving-factor buffer, followed by the neighbourhood feature.
                List<double> input = (from buffer in driveBufferList
                                      select buffer[pos]).ToList<double>();
                input.Add(GetNeighbourAffect(beginBuffer, width, height, cell.row, cell.col, 3));
                inputs[i] = input.ToArray<double>();

                outputs[i] = positiveLandUseValue == (int)beginBuffer[pos] ? 1 : 0;
            }

            // The split is by position, not shuffled.
            // NOTE(review): if getSample() returns cells in a systematic order, the two halves
            // may not be comparable. Confirm against getSample().
            int half = count / 2;

            // Training set: the first half of the samples.
            var trainingInputs = inputs.Submatrix(0, half - 1);
            var trainingOutput = outputs.Submatrix(0, half - 1);

            // Validation (pruning) set: the remaining samples.
            var pruningInputs = inputs.Submatrix(half, count - 1);
            var pruningOutput = outputs.Submatrix(half, count - 1);

            // Feature descriptors: one continuous variable per driving-factor layer name, plus
            // the neighbourhood feature appended to each input vector above.
            // NOTE(review): this assumes driveLayerNames lists the layers in the same order as
            // driveBufferList. Confirm.
            List<DecisionVariable> featuresList = (from column in this.driveLayerNames
                                                   select new DecisionVariable(column, DecisionVariableKind.Continuous)).ToList<DecisionVariable>();

            featuresList.Add(new DecisionVariable("affectofneighbour", DecisionVariableKind.Continuous));

            // Train the tree.
            var tree    = new DecisionTree(inputs: featuresList, classes: 2);
            var teacher = new C45Learning(tree);

            teacher.Learn(trainingInputs, trainingOutput);

            // Prune the tree.
            ErrorBasedPruning prune = new ErrorBasedPruning(tree, pruningInputs, pruningOutput);

            // NOTE(review): the original comment questioned whether this is a gain threshold.
            // Confirm against the ErrorBasedPruning.Threshold documentation.
            prune.Threshold = 0.1;

            // Keep running pruning passes while each pass lowers the reported error.
            double lastError;
            double error = Double.PositiveInfinity;

            do
            {
                lastError = error;
                error     = prune.Run();
            } while (error < lastError);

            // Report the error from the final pruning pass ("错误率" = error rate).
            updateConsoleEvent("错误率:" + error);

            this.func = tree.ToExpression().Compile();

            // Extract the learned rules and report them ("规则" = rules).
            DecisionSet rules    = tree.ToRules();
            string      ruleText = rules.ToString();

            updateConsoleEvent("规则:");
            updateConsoleEvent(ruleText);
            updateConsoleEvent("-----------训练结束----------");
        }