コード例 #1
0
        /// <summary>
        /// Removes the given <see cref="ILearningPipelineItem"/> from the underlying pipeline.
        /// </summary>
        /// <param name="item">The pipeline item to remove; must not be <c>null</c>.</param>
        /// <returns>This instance, so that further calls can be chained.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="item"/> is <c>null</c>.</exception>
        /// <remarks>
        /// When <c>HasPipeline</c> is <c>false</c> the removal is skipped and this instance is still returned.
        /// </remarks>
        public IFluentLearningPipeline<TInput, TOutput> Remove(ILearningPipelineItem item)
        {
            // Reject null up front, before the pipeline is consulted at all.
            _ = item ?? throw new ArgumentNullException(nameof(item));

            if (!HasPipeline)
            {
                return this;
            }

            _pipeline.Remove(item);
            return this;
        }
コード例 #2
0
        public void AppendPipeline()
        {
            // Checks that Count tracks items added through chained Append calls,
            // removed through Remove, and appended again afterwards.
            var learningPipeline = new LearningPipeline();

            learningPipeline
                .Append(new CategoricalOneHotVectorizer("String1", "String2"))
                .Append(new ColumnConcatenator(outputColumn: "Features", "String1", "String2", "Number1", "Number2"))
                .Append(new StochasticDualCoordinateAscentRegressor());

            Assert.NotNull(learningPipeline);
            Assert.Equal(3, learningPipeline.Count);

            // Remove whatever the pipeline enumerates at index 2; the count must drop by one.
            var itemAtIndexTwo = learningPipeline.ElementAt(2);
            learningPipeline.Remove(itemAtIndexTwo);
            Assert.Equal(2, learningPipeline.Count);

            // Appending a fresh regressor brings the count back to three.
            learningPipeline.Append(new StochasticDualCoordinateAscentRegressor());
            Assert.Equal(3, learningPipeline.Count);
        }
コード例 #3
0
        public void CanAddAndRemoveFromPipeline()
        {
            // Checks that Count tracks items added through Add and removed through Remove.
            var learningPipeline = new LearningPipeline();
            learningPipeline.Add(new Transforms.CategoricalOneHotVectorizer("String1", "String2"));
            learningPipeline.Add(new Transforms.ColumnConcatenator(outputColumn: "Features", "String1", "String2", "Number1", "Number2"));
            learningPipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor());

            Assert.NotNull(learningPipeline);
            Assert.Equal(3, learningPipeline.Count);

            // Remove whatever the pipeline enumerates at index 2; the count must drop by one.
            var itemAtIndexTwo = learningPipeline.ElementAt(2);
            learningPipeline.Remove(itemAtIndexTwo);
            Assert.Equal(2, learningPipeline.Count);

            // Adding a fresh regressor brings the count back to three.
            learningPipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor());
            Assert.Equal(3, learningPipeline.Count);
        }
コード例 #4
0
        /// <summary>
        /// Grid-searches the <c>NumTrees</c> / <c>NumLeaves</c> settings of a <c>FastForestBinaryClassifier</c>,
        /// retrains the pipeline with the best-scoring classifier, writes the resulting model to
        /// "model.zip" and prints the train, cross-validation and test metrics to the console.
        /// </summary>
        /// <returns>A task that completes after the model has been written and the metrics printed.</returns>
        /// <exception cref="InvalidOperationException">
        /// No grid candidate reached a cross-validation F1 score greater than 0, so no classifier could be selected.
        /// </exception>
        static async Task Train()
        {
            // Create the pipeline.
            var pipeline = new LearningPipeline();

            // Load the training data.
            var trainingSets = new TextLoader<TitanicData>(trainSetPath, useHeader: true, separator: ",");

            pipeline.Add(trainingSets);

            // Drop rows whose Age value is missing.
            pipeline.Add(new MissingValuesRowDropper
            {
                Column = new[] { "Age" }
            });

            // Turn the non-numeric variables into one-hot vectors.
            pipeline.Add(new CategoricalOneHotVectorizer("Sex", "Embarked"));

            // Concatenate the variables used by the model into a single "Features" column.
            pipeline.Add(new ColumnConcatenator("Features",
                                                "Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"));

            // Cross-validation data.
            var cvSets = new TextLoader<TitanicData>(crossValidationSetPath, useHeader: true, separator: ",");
            // Candidate values for the grid search.
            var n_trees  = new int[] { 2, 4, 8, 16, 32, 64, 128 };
            var n_leaves = new int[] { 2, 4, 8, 16, 32, 64, 128 };
            // Best F1 score seen so far.
            var bestf1 = 0.0;
            // Classifier that produced the best F1 score; stays null until a candidate beats bestf1.
            FastForestBinaryClassifier bestClassifier = null;
            // Evaluator for binary classification.
            var evaluator = new BinaryClassificationEvaluator();

            foreach (var nt in n_trees)
            {
                foreach (var nl in n_leaves)
                {
                    // Fast-forest binary classifier for this grid point.
                    var classifier = new FastForestBinaryClassifier
                    {
                        NumTrees  = nt,
                        NumLeaves = nl
                    };
                    pipeline.Add(classifier);
                    // Train.
                    var model = pipeline.Train<TitanicData, TitanicPrediction>();

                    // F1 score on the cross-validation data.
                    var metrics = evaluator.Evaluate(model, cvSets);
                    Console.WriteLine($"#tree = {nt}, #leaf = {nl}, cv_f1={metrics.F1Score}");
                    if (!double.IsNaN(metrics.F1Score) && metrics.F1Score > bestf1)
                    {
                        Console.WriteLine($"[!]Classifier Updated {bestf1} -> {metrics.F1Score} / nt : {nt}, nl : {nl}");
                        bestf1         = metrics.F1Score;
                        bestClassifier = classifier;
                    }

                    // Take the classifier back out so the next grid point starts from the shared stages only.
                    pipeline.Remove(classifier);
                }
            }

            // If every candidate scored 0 or NaN, nothing was selected. Fail with a clear message here
            // instead of adding null to the pipeline and dereferencing it further down.
            if (bestClassifier == null)
            {
                throw new InvalidOperationException(
                    "Grid search did not produce any classifier with a cross-validation F1 score greater than 0.");
            }

            // Add the best classifier found by the grid search to the pipeline.
            pipeline.Add(bestClassifier);

            // Train the final model.
            var bestModel = pipeline.Train<TitanicData, TitanicPrediction>();
            // Metrics on the training data.
            var trainMetrics = evaluator.Evaluate(bestModel, trainingSets);
            // Metrics on the cross-validation data.
            var cvMetrics = evaluator.Evaluate(bestModel, cvSets);
            // Test data.
            var testSets = new TextLoader<TitanicData>(testSetPath, useHeader: true, separator: ",");
            // Metrics on the test data.
            var testMetrics = evaluator.Evaluate(bestModel, testSets);

            // Save the model.
            await bestModel.WriteAsync("model.zip");

            // Print the results.
            Console.WriteLine("### Result ###");
            Console.WriteLine("- Selected Classifier");
            Console.WriteLine($"NumTree={bestClassifier.NumTrees}, NumLeaves={bestClassifier.NumLeaves}");
            Console.WriteLine("- Trian Sets");
            Console.WriteLine($"Accuracy={trainMetrics.Accuracy:P2}, " +
                              $"Precision={trainMetrics.PositivePrecision:P2}, " +
                              $"Recall={trainMetrics.PositiveRecall:P2}, " +
                              $"F1score={trainMetrics.F1Score:P2}");
            Console.WriteLine("- Cross Validation Sets");
            Console.WriteLine($"Accuracy={cvMetrics.Accuracy:P2}, " +
                              $"Precision={cvMetrics.PositivePrecision:P2}, " +
                              $"Recall={cvMetrics.PositiveRecall:P2}, " +
                              $"F1score={cvMetrics.F1Score:P2}");
            Console.WriteLine("- Test Sets");
            Console.WriteLine($"Accuracy={testMetrics.Accuracy:P2}, " +
                              $"Precision={testMetrics.PositivePrecision:P2}, " +
                              $"Recall={testMetrics.PositiveRecall:P2}, " +
                              $"F1score={testMetrics.F1Score:P2}");

            /*
             Example output kept from the original source:
             ### Result ###
             ###- Selected Classifier
             ###NumTree=4, NumLeaves=32
             ###- Trian Sets
             ###Accuracy=83.45%, Precision=83.78%, Recall=72.94%, F1score=77.99%
             ###- Cross Validation Sets
             ###Accuracy=80.28%, Precision=86.67%, Recall=63.93%, F1score=73.58%
             ###- Test Sets
             ###Accuracy=86.58%, Precision=86.79%, Recall=77.97%, F1score=82.14%
             */
        }