示例#1
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_maketree_check_all_validation_data_set()
        {
            //Trains CART on the Jason data set and counts how many validation
            //rows (columns 0 and 1 as inputs, column 2 as target) it predicts exactly.
            initData_Jason();
            BuildCART cart = new BuildCART();

            cart.SetMissingValue(999);
            cart.SetParameters(1);

            ModelCART mb =
                (ModelCART)cart.BuildModel(_trainingData,
                                           _attributeHeaders,
                                           _indexTargetAttribute);

            int count = 0;
            for (int row = 0; row < _validationData[0].Length; row++)
            {
                double[] data  = new double[] { _validationData[0][row], _validationData[1][row] };
                double   value = mb.RunModelForSingleData(data);

                //Exact comparison: the model is expected to return a target
                //value taken from the data, not a computed one
                if (value == _validationData[2][row])
                {
                    count++;
                }
            }

            //MSTest signature is AreEqual(expected, actual)
            Assert.AreEqual(9, count);
        }
示例#2
0
        /// <summary>
        /// Builds a bagged ensemble of CART trees. Each tree is trained on its own
        /// bootstrap sample (random sampling with replacement) of the training rows.
        /// </summary>
        /// <param name="trainingData">Training data indexed [column][row]; the target
        /// column is at <paramref name="indexTargetAttribute"/>.</param>
        /// <param name="attributeHeaders">Column names.</param>
        /// <param name="indexTargetAttribute">Index of the target column.</param>
        /// <returns>A <c>ModelBaggedDecisionTree</c> holding one CART per tree.</returns>
        public override Common.MLCore.ModelBase BuildModel(
            double[][] trainingData,
            string[] attributeHeaders,
            int indexTargetAttribute)
        {
            //Verify data and set variables
            VerifyData(trainingData, attributeHeaders, indexTargetAttribute);

            ModelBaggedDecisionTree model =
                new ModelBaggedDecisionTree(_missingValue,
                                            _indexTargetAttribute,
                                            _trainingData.Length - 1,
                                            _numberOfTrees);

            int numberOfRows = _trainingData[0].Length;

            //By default samples/tree is same as original samples.
            //Resolved into a local so the long.MaxValue sentinel in the field is
            //kept and a later call with a different data set re-derives the default.
            long samplesPerTree = _numberOfSamplesPerTree == long.MaxValue
                ? numberOfRows
                : _numberOfSamplesPerTree;

            //Draw one seed per tree up front, on this thread. Calling 'new Random()'
            //inside concurrently-started iterations can hand several trees the same
            //time-based seed (.NET Framework), producing identical bootstrap samples;
            //a single shared Random is not thread-safe either.
            Random seedSource = new Random();
            int[] treeSeeds = new int[_numberOfTrees];
            for (int t = 0; t < treeSeeds.Length; t++)
            {
                treeSeeds[t] = seedSource.Next();
            }

            //Build the trees in parallel, one bootstrap sample each
            Parallel.For(0, _numberOfTrees, new ParallelOptions {
                MaxDegreeOfParallelism = _maxParallelThreads
            }, ii =>
            {
                Random rnd = new Random(treeSeeds[ii]);
                ConcurrentBag<long> trainingDataRowIndices =
                    new ConcurrentBag<long>();

                for (long idx = 0; idx < samplesPerTree; idx++)
                {
                    //Random sampling with replacement. Random.Next's upper bound
                    //is exclusive, so numberOfRows lets the last row be drawn too.
                    long rowIdx = rnd.Next(0, numberOfRows);
                    trainingDataRowIndices.Add(rowIdx);
                }

                BuildCART buildCart = new BuildCART();
                ModelCART modelCart = (ModelCART)buildCart.BuildModel(trainingData,
                                                                      attributeHeaders,
                                                                      indexTargetAttribute,
                                                                      trainingDataRowIndices);
                model.AddTree(ii, modelCart);
            });

            return model;
        }
示例#3
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_maketree_check_root_node_value()
        {
            //Trains CART on the Jason data set and checks the root split:
            //attribute 0 at value 6.642287351.
            initData_Jason();
            BuildCART cart = new BuildCART();

            cart.SetMissingValue(999);
            cart.SetParameters(1);

            ModelCART mb =
                (ModelCART)cart.BuildModel(_trainingData,
                                           _attributeHeaders,
                                           _indexTargetAttribute);

            //MSTest signature is AreEqual(expected, actual)
            Assert.AreEqual(0, mb.Root.AttributeIndex);
            Assert.AreEqual(6.642287351, mb.Root.Value);
        }
示例#4
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_maketree_special_no_splitting_possible_true()
        {
            //Builds CART on a data set that presumably admits no split, then checks
            //that training row 1 is predicted as its own target value.
            initData_special_no_splitting_possible();
            BuildCART cart = new BuildCART();
            ModelCART model =
                (ModelCART)cart.BuildModel(_trainingData,
                                           _attributeHeaders,
                                           _indexTargetAttribute);

            int row = 1;

            double[] data  = GetSingleTrainingRowDataForTest(row);
            double   value = model.RunModelForSingleData(data);

            //MSTest signature is AreEqual(expected, actual)
            Assert.AreEqual(_trainingData[_indexTargetAttribute][row],
                            value);
        }
示例#5
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_maketree_check_root_node_structure()
        {
            //Trains CART on the Jason data set and checks that the root has
            //exactly two children, both of which are leaves.
            initData_Jason();
            BuildCART cart = new BuildCART();

            cart.SetMissingValue(999);
            cart.SetParameters(1);

            ModelBase mb =
                (ModelBase)cart.BuildModel(_trainingData,
                                           _attributeHeaders,
                                           _indexTargetAttribute);

            //MSTest signature is AreEqual(expected, actual)
            Assert.AreEqual(2, mb.Root.Children.Count);
            Assert.IsNull(mb.Root.Children[0].Children);
            Assert.IsNull(mb.Root.Children[1].Children);
        }
示例#6
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_gini_index_test()
        {
            //Arrange: Jason data set, with the builder's private state prepared
            //so the private GetGiniImpurity method can be called directly
            initData_Jason();
            BuildCART cart = new BuildCART();
            setPrivateVariablesInBuildObject(cart);
            PrivateObject privateCart = new PrivateObject(cart);
            object[] giniArgs = { 0, 2.771244718 };

            //Act: Gini impurity for attribute 0 split at 2.771244718
            double gini = (double)privateCart.Invoke("GetGiniImpurity", giniArgs);

            //Assert
            bool matchesExpected = SupportFunctions.DoubleCompare(gini, 0.49382716);
            Assert.IsTrue(matchesExpected);
        }
示例#7
0
文件: CARTTest.cs 项目: Dasmic/MLLib
        public void CART_maketree_validate_single_training_data()
        {
            //Trains CART on the Jason data set and checks that training row 0
            //(columns 0 and 1 as inputs) is predicted as its target in column 2.
            initData_Jason();
            BuildCART cart = new BuildCART();

            cart.SetMissingValue(999);
            cart.SetParameters(1);

            ModelCART mb =
                (ModelCART)cart.BuildModel(_trainingData,
                                           _attributeHeaders,
                                           _indexTargetAttribute);

            double[] data  = new double[] { _trainingData[0][0], _trainingData[1][0] };
            double   value = mb.RunModelForSingleData(data);

            //MSTest signature is AreEqual(expected, actual)
            Assert.AreEqual(_trainingData[2][0], value);
        }
示例#8
0
        /// <summary>
        /// Builds a random-forest ensemble of CART trees. Each tree is built by a
        /// <c>BuildCART</c> configured to consider a limited number of features.
        /// </summary>
        /// <param name="trainingData">Training data indexed [column][row]; the target
        /// column is at <paramref name="indexTargetAttribute"/>.</param>
        /// <param name="attributeHeaders">Column names.</param>
        /// <param name="indexTargetAttribute">Index of the target column.</param>
        /// <returns>A <c>ModelRandomForest</c> holding one CART per tree.</returns>
        public override Common.MLCore.ModelBase BuildModel(
            double[][] trainingData,
            string[] attributeHeaders,
            int indexTargetAttribute)
        {
            //Verify data and set variables
            VerifyData(trainingData, attributeHeaders, indexTargetAttribute);

            //All columns except the target are candidate features
            int numberOfFeatures = _trainingData.Length - 1;

            ModelRandomForest model =
                new ModelRandomForest(_missingValue,
                                      _indexTargetAttribute,
                                      numberOfFeatures,
                                      _numberOfTrees);

            //By default features/tree is ceil(sqrt(number of features)), at least 1.
            //The target column is excluded from the count. The value is resolved into
            //a local so the int.MaxValue sentinel in the field is kept and a later
            //call with a different data set re-derives the default.
            int featuresPerTree = _numberOfFeaturesPerTree == int.MaxValue
                ? System.Math.Max(1, (int)System.Math.Ceiling(
                      System.Math.Sqrt((double)numberOfFeatures)))
                : _numberOfFeaturesPerTree;

            //Build the trees in parallel
            Parallel.For(0, _numberOfTrees,
                         new ParallelOptions {
                MaxDegreeOfParallelism = _maxParallelThreads
            }, ii =>
            {
                BuildCART buildCart = new BuildCART();
                buildCart.SetParametersForRandomForest(featuresPerTree);
                ModelCART modelCart = (ModelCART)buildCart.BuildModel(trainingData,
                                                                      attributeHeaders,
                                                                      indexTargetAttribute);
                model.AddTree(ii, modelCart);
            });

            return model;
        }