Exemplo n.º 1
0
        /// <summary>
        /// Builds the taxi-fare regression pipeline, trains it on the CSV at <c>_dataPath</c>,
        /// saves the trained model to <c>_modelPath</c>, and returns it.
        /// </summary>
        /// <returns>The trained <see cref="PredictionModel{TInput, TOutput}"/>.</returns>
        public static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> TrainModel()
        {
            var trainingPipeline = new LearningPipeline();

            // Load comma-separated rows (first line is a header) as TaxiTrip records.
            trainingPipeline.Add(new TextLoader(_dataPath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','));

            // Copy the FareAmount target into a column named "Label".
            trainingPipeline.Add(new ColumnCopier(("FareAmount", "Label")));

            // One-hot encode the categorical columns so they become numeric.
            trainingPipeline.Add(new CategoricalOneHotVectorizer("VendorId", "RateCode", "PaymentType"));

            // Gather every input column into the single "Features" column.
            trainingPipeline.Add(new ColumnConcatenator(
                "Features",
                "VendorId", "RateCode", "PassengerCount", "TripDistance", "PaymentType"));

            trainingPipeline.Add(new FastTreeRegressor());

            PredictionModel<TaxiTrip, TaxiTripFarePrediction> trainedModel =
                trainingPipeline.Train<TaxiTrip, TaxiTripFarePrediction>();

            // Persist before handing the model back to the caller.
            await trainedModel.WriteAsync(_modelPath);
            return trainedModel;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Trains a tiny FastTree binary classifier on breast-cancer.txt with text column F2
        /// one-hot encoded in Bag mode. It then exports the model to ONNX (binary and JSON),
        /// replaces the producer version in the JSON with a placeholder, and compares the
        /// JSON with the stored baseline.
        /// </summary>
        public void KeyToVectorWithBagTest()
        {
            string dataPath = GetDataPath(@"breast-cancer.txt");
            var pipeline = new Legacy.LearningPipeline();

            // Tab-separated input with a header row:
            // column 0 -> numeric Label, column 1 -> numeric F1, column 2 -> text F2.
            var loaderArguments = new TextLoaderArguments
            {
                Separator = new[] { '\t' },
                HasHeader = true,
                Column = new[]
                {
                    new TextLoaderColumn()
                    {
                        Name = "Label",
                        Source = new[] { new TextLoaderRange(0) },
                        Type = Legacy.Data.DataKind.Num
                    },
                    new TextLoaderColumn()
                    {
                        Name = "F1",
                        Source = new[] { new TextLoaderRange(1, 1) },
                        Type = Legacy.Data.DataKind.Num
                    },
                    new TextLoaderColumn()
                    {
                        Name = "F2",
                        Source = new[] { new TextLoaderRange(2, 2) },
                        Type = Legacy.Data.DataKind.TX
                    }
                }
            };
            pipeline.Add(new Legacy.Data.TextLoader(dataPath) { Arguments = loaderArguments });

            // Encode the text column F2 in place, using the Bag output kind.
            var bagVectorizer = new CategoricalOneHotVectorizer
            {
                Column = new[]
                {
                    new CategoricalTransformColumn()
                    {
                        OutputKind = CategoricalTransformOutputKind.Bag,
                        Name = "F2",
                        Source = "F2"
                    }
                }
            };
            pipeline.Add(bagVectorizer);
            pipeline.Add(new ColumnConcatenator("Features", "F1", "F2"));

            // Deliberately minimal tree settings.
            pipeline.Add(new FastTreeBinaryClassifier()
            {
                NumLeaves = 2,
                NumTrees = 1,
                MinDocumentsInLeafs = 2
            });

            var trainedModel = pipeline.Train<BreastCancerData, BreastCancerPrediction>();

            var baselineDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer");

            // Clear any output left over from a previous run.
            var onnxPath = GetOutputPath(baselineDir, "KeyToVectorBag.onnx");
            DeleteOutputPath(onnxPath);
            var jsonPath = GetOutputPath(baselineDir, "KeyToVectorBag.json");
            DeleteOutputPath(jsonPath);

            new OnnxConverter()
            {
                InputsToDrop = new[] { "Label" },
                OutputsToDrop = new[] { "Label", "F1", "F2", "Features" },
                Onnx = onnxPath,
                Json = jsonPath,
                Domain = "Onnx"
            }.Convert(trainedModel);

            // Replace the producerVersion value with a fixed placeholder before
            // comparing the JSON against the baseline.
            string normalizedJson = Regex.Replace(
                File.ReadAllText(jsonPath),
                "\"producerVersion\": \"([^\"]+)\"",
                "\"producerVersion\": \"##VERSION##\"");
            File.WriteAllText(jsonPath, normalizedJson);

            CheckEquality(baselineDir, "KeyToVectorBag.json");
            Done();
        }