コード例 #1
0
        public void SaveAndLoadModel()
        {
            // Round-trip test: train a classifier on two separable 2D clusters, persist it to
            // disk, reload it into a fresh instance and check that the reloaded model labels
            // freshly generated points from the same clusters correctly.

            // Train a model.
            // NOTE(review): GenerateRandom2dPoints args presumably are
            // (count, xMin, xMax, yMin, yMax, label) — confirm against Util.
            var xgbTrainer          = new XGBoost.XGBClassifier();
            int countTrainingPoints = 20;

            entity.XGBArray trainClass1         = Util.GenerateRandom2dPoints(countTrainingPoints / 2, -1.0, 0.0, 0.0, 1.0, 0.0); // Top left quadrant
            entity.XGBArray trainClass2         = Util.GenerateRandom2dPoints(countTrainingPoints / 2, 0.0, 1.0, -1.0, 0.0, 1.0); // Bottom right quadrant
            entity.XGBArray train_Class1_Class2 = Util.UnionOfXGBArray(trainClass1, trainClass2);
            xgbTrainer.Fit(train_Class1_Class2.Vectors, train_Class1_Class2.Labels);

            // Save the model, first removing any stale file left by a previous run.
            string fileModel = "MyLinearModel.dat";

            if (System.IO.File.Exists(fileModel))
            {
                System.IO.File.Delete(fileModel);
            }
            xgbTrainer.SaveModelToFile(fileModel);

            // Load the saved model into a separate instance so only the persisted state is tested.
            var xgbProduction      = XGBoost.XGBClassifier.LoadClassifierFromFile(fileModel);
            int countTestingPoints = 50;

            // Test points are drawn from ranges inset from the quadrant borders, so they are
            // unambiguously on one side of the decision boundary.
            entity.XGBArray testClass1         = Util.GenerateRandom2dPoints(countTestingPoints / 2, -0.8, -0.2, 0.2, 0.8, 0.0); // Top left quadrant
            entity.XGBArray testClass2         = Util.GenerateRandom2dPoints(countTestingPoints / 2, 0.2, 0.8, -0.8, -0.2, 1.0); // Bottom right quadrant
            entity.XGBArray test_Class1_Class2 = Util.UnionOfXGBArray(testClass1, testClass2);
            var             results            = xgbProduction.Predict(test_Class1_Class2.Vectors);

            // CollectionAssert.AreEqual takes (expected, actual); keep that order so a failure
            // message reports the roles correctly.
            CollectionAssert.AreEqual(test_Class1_Class2.Labels, results);
        }
コード例 #2
0
        public void TrainAndTestIris()
        {
            // Train a multi-class ("multi:softprob") classifier on the Iris training split,
            // assert that every sample of the test split is classified correctly, then save
            // the trained model under the project directory.

            // Load training vectors.
            // Path.Combine yields "Iris\Iris.train.data" on Windows and stays portable elsewhere.
            string filenameTrain = System.IO.Path.Combine("Iris", "Iris.train.data");

            iris.Iris[] recordsTrain = IrisUtils.LoadIris(filenameTrain);
            entity.XGVector<iris.Iris>[] vectorsTrain = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTrain);

            // Load testing vectors.
            string filenameTest = System.IO.Path.Combine("Iris", "Iris.test.data");

            iris.Iris[] recordsTest = IrisUtils.LoadIris(filenameTest);
            entity.XGVector<iris.Iris>[] vectorsTest = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTest);

            // Single source of truth for the class count: it configures the model and also
            // determines how the flat prediction array is sliced per sample below.
            const int noOfClasses = 3;
            var xgbc = new XGBoost.XGBClassifier(objective: "multi:softprob", numClass: noOfClasses);

            entity.XGBArray arrTrain = Util.ConvertToXGBArray(vectorsTrain);
            entity.XGBArray arrTest  = Util.ConvertToXGBArray(vectorsTest);
            xgbc.Fit(arrTrain.Vectors, arrTrain.Labels);

            // The prediction is consumed as a flat array holding noOfClasses consecutive
            // per-class scores for each test sample.
            var outcomeTest = xgbc.Predict(arrTest.Vectors);

            // Fail with a clear message, rather than an IndexOutOfRangeException, if the
            // output shape is not what the slicing below expects.
            Assert.AreEqual(arrTest.Vectors.Length * noOfClasses, outcomeTest.Length,
                "Predict should return one score per class for every test sample.");

            for (int index = 0; index < arrTest.Vectors.Length; index++)
            {
                string sExpected = IrisUtils.ConvertLabelFromNumericToString(arrTest.Labels[index]);

                // Slice out this sample's per-class scores.
                float[] arrResults = new float[noOfClasses];
                System.Array.Copy(outcomeTest, index * noOfClasses, arrResults, 0, noOfClasses);

                // The predicted class is the index of the highest score.
                int    indexWithMaxValue = Util.GetIndexWithMaxValue(arrResults);
                string sActualClass      = IrisUtils.ConvertLabelFromNumericToString((float)indexWithMaxValue);
                Trace.WriteLine($"{index}       Expected={sExpected}        Actual={sActualClass}");

                // Assert.AreEqual takes (expected, actual); keep that order for correct failure messages.
                Assert.AreEqual(sExpected, sActualClass);
            }

            string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), _fileModelIris);

            xgbc.SaveModelToFile(pathFull);
        }