static void Main() { var pipeline = new LearningPipeline { new TextLoader(FileName).CreateFrom <AgeRange>(separator: ',', useHeader: true), new Dictionarizer("Label"), new TextFeaturizer("Gender", "Gender"), new ColumnConcatenator("Features", "Age", "Gender"), new StochasticDualCoordinateAscentClassifier(), new PredictedLabelColumnOriginalValueConverter { PredictedLabelColumn = "PredictedLabel" } }; var model = pipeline.Train <AgeRange, AgeRangePrediction>(); var converter = new OnnxConverter { Onnx = OnnxPath, Json = OnnxAsJsonPath, Domain = "com.elbruno" }; converter.Convert(model); // Strip the version. var fileText = File.ReadAllText(OnnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(OnnxAsJsonPath, fileText); }
public void WordEmbeddingsTest() { string dataPath = GetDataPath(@"small-sentiment-test.tsv"); var pipeline = new Legacy.LearningPipeline(0); pipeline.Add(new Legacy.Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments { Separator = new[] { '\t' }, HasHeader = false, Column = new[] { new TextLoaderColumn() { Name = "Cat", Source = new [] { new TextLoaderRange(0, 3) }, Type = Legacy.Data.DataKind.TX }, } } }); var modelPath = GetDataPath(@"shortsentiment.emd"); var embed = new WordEmbeddings() { CustomLookupTable = modelPath }; embed.AddColumn("Cat", "Cat"); pipeline.Add(embed); var model = pipeline.Train <EmbeddingsData, EmbeddingsResult>(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "WordEmbeddings"); var onnxPath = GetOutputPath(subDir, "WordEmbeddings.onnx"); DeleteOutputPath(onnxPath); var onnxAsJsonPath = GetOutputPath(subDir, "WordEmbeddings.json"); DeleteOutputPath(onnxAsJsonPath); OnnxConverter converter = new OnnxConverter() { Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }; converter.Convert(model); var fileText = File.ReadAllText(onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(onnxAsJsonPath, fileText); CheckEquality(subDir, "WordEmbeddings.json"); Done(); }
public void BinaryClassificationSaveModelToOnnxTest() { string dataPath = GetDataPath(@"breast-cancer.txt"); var pipeline = new LearningPipeline(); pipeline.Add(new Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments { Separator = new[] { '\t' }, HasHeader = true, Column = new[] { new TextLoaderColumn() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Features", Source = new [] { new TextLoaderRange(1, 9) }, Type = Data.DataKind.Num } } } }); pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); var model = pipeline.Train <BreastCancerData, BreastCancerPrediction>(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer"); var onnxPath = GetOutputPath(subDir, "SaveModelToOnnxTest.pb"); DeleteOutputPath(onnxPath); var onnxAsJsonPath = GetOutputPath(subDir, "SaveModelToOnnxTest.json"); DeleteOutputPath(onnxAsJsonPath); OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }; converter.Convert(model); CheckEquality(subDir, "SaveModelToOnnxTest.json"); Done(); }
public void KmeansTest() { string dataPath = GetDataPath(@"breast-cancer.txt"); var pipeline = new Legacy.LearningPipeline(0); pipeline.Add(new Legacy.Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments { Separator = new[] { '\t' }, HasHeader = true, Column = new[] { new TextLoaderColumn() { Name = "Features", Source = new [] { new TextLoaderRange(1, 9) }, Type = Legacy.Data.DataKind.R4 }, } } }); pipeline.Add(new KMeansPlusPlusClusterer() { K = 2, MaxIterations = 1, NumThreads = 1, InitAlgorithm = KMeansPlusPlusTrainerInitAlgorithm.Random }); var model = pipeline.Train <BreastNumericalColumns, BreastCancerClusterPrediction>(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Cluster", "BreastCancer"); var onnxPath = GetOutputPath(subDir, "Kmeans.onnx"); DeleteOutputPath(onnxPath); var onnxAsJsonPath = GetOutputPath(subDir, "Kmeans.json"); DeleteOutputPath(onnxAsJsonPath); OnnxConverter converter = new OnnxConverter() { Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }; converter.Convert(model); // Strip the version. var fileText = File.ReadAllText(onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(onnxAsJsonPath, fileText); CheckEquality(subDir, "Kmeans.json"); Done(); }
public static void SaveToOnnx(PredictionModel model) { //Sauvegarde sous format ONNX (pas encore fonctionnel pour multiclass à la version 0.3) OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = _onnxPath, Domain = "Onnx" }; converter.Convert(model); }
private static void ConvertToOnnx(PredictionModel model) { try { OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = _onnxPath, Json = _onnxAsJsonPath, Domain = "com.mydomain" }; converter.Convert(model); // Strip the version. var fileText = File.ReadAllText(_onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(_onnxAsJsonPath, fileText); } catch (Exception e) { System.Console.WriteLine(e); } }
private void Export() { var onnxPath = "./SaveModelToOnnxTest.onnx"; var onnxAsJsonPath = "./SaveModelToOnnxTest.json"; OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "com.mydomain" }; converter.Convert(_model); // Strip the version. var fileText = File.ReadAllText(onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(onnxAsJsonPath, fileText); }
public void MultiClassificationLRSaveModelToOnnxTest() { string dataPath = GetDataPath(@"breast-cancer.txt"); var pipeline = new Legacy.LearningPipeline(); pipeline.Add(new Legacy.Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments { Separator = new[] { '\t' }, HasHeader = true, Column = new[] { new TextLoaderColumn() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, Type = Legacy.Data.DataKind.Num }, new TextLoaderColumn() { Name = "Features", Source = new [] { new TextLoaderRange(1, 9) }, Type = Legacy.Data.DataKind.Num } } } }); pipeline.Add(new Dictionarizer("Label")); pipeline.Add(new LogisticRegressionClassifier() { UseThreads = false }); var model = pipeline.Train <BreastCancerDataAllColumns, BreastCancerMCPrediction>(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "MultiClassClassification", "BreastCancer"); var onnxPath = GetOutputPath(subDir, "MultiClassificationLRSaveModelToOnnxTest.onnx"); DeleteOutputPath(onnxPath); var onnxAsJsonPath = GetOutputPath(subDir, "MultiClassificationLRSaveModelToOnnxTest.json"); DeleteOutputPath(onnxAsJsonPath); OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }; converter.Convert(model); // Strip the version. var fileText = File.ReadAllText(onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(onnxAsJsonPath, fileText); CheckEquality(subDir, "MultiClassificationLRSaveModelToOnnxTest.json"); Done(); }
public void KeyToVectorWithBagTest() { string dataPath = GetDataPath(@"breast-cancer.txt"); var pipeline = new Legacy.LearningPipeline(); pipeline.Add(new Legacy.Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments { Separator = new[] { '\t' }, HasHeader = true, Column = new[] { new TextLoaderColumn() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, Type = Legacy.Data.DataKind.Num }, new TextLoaderColumn() { Name = "F1", Source = new [] { new TextLoaderRange(1, 1) }, Type = Legacy.Data.DataKind.Num }, new TextLoaderColumn() { Name = "F2", Source = new [] { new TextLoaderRange(2, 2) }, Type = Legacy.Data.DataKind.TX } } } }); var vectorizer = new CategoricalOneHotVectorizer(); var categoricalColumn = new CategoricalTransformColumn() { OutputKind = CategoricalTransformOutputKind.Bag, Name = "F2", Source = "F2" }; vectorizer.Column = new CategoricalTransformColumn[1] { categoricalColumn }; pipeline.Add(vectorizer); pipeline.Add(new ColumnConcatenator("Features", "F1", "F2")); pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 2, NumTrees = 1, MinDocumentsInLeafs = 2 }); var model = pipeline.Train <BreastCancerData, BreastCancerPrediction>(); var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer"); var onnxPath = GetOutputPath(subDir, "KeyToVectorBag.onnx"); DeleteOutputPath(onnxPath); var onnxAsJsonPath = GetOutputPath(subDir, "KeyToVectorBag.json"); DeleteOutputPath(onnxAsJsonPath); OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "F1", "F2", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "Onnx" }; converter.Convert(model); // Strip the version. var fileText = File.ReadAllText(onnxAsJsonPath); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(onnxAsJsonPath, fileText); CheckEquality(subDir, "KeyToVectorBag.json"); Done(); }
public async Task <string> TrainAsync(string trainpath, bool writeToDisk = true) { // pipeline encapsulates the data loading, data processing/featurization, and learning algorithm var pipeline = new LearningPipeline { // load from CSV --> SubCategory, Description, Bank, Amount, new TextLoader(trainpath).CreateFrom <BankStatementLineItem>(separator: ',', useHeader: true), //Converts input values (words, numbers, etc.) to index in a dictionary. new Dictionarizer(("SubCategory", "Label")), // convert the data columns to the feature. For that TextFeaturizer // ngram analysis over the transaction description new TextFeaturizer("Description", "Description") { TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, WordFeatureExtractor = new NGramNgramExtractor { // Term frequency -- the number of times that term t occurs in document d Weighting = NgramTransformWeightingCriteria.Tf } }, new TextFeaturizer("Bank", "Bank") { TextCase = TextNormalizerTransformCaseNormalizationMode.Lower }, // feature column using bank and description new ColumnConcatenator("Features", "Bank", "Description"), //******************************************************************** // classifiers //******************************************************************** //new NaiveBayesClassifier(), new StochasticDualCoordinateAscentClassifier { Shuffle = false, NumThreads = 1 }, //new LightGbmClassifier(), //******************************************************************** //Transforms a predicted label column to its original values, unless it is of type bool new PredictedLabelColumnOriginalValueConverter { PredictedLabelColumn = "PredictedLabel" } }; //******************************************************************** // training //******************************************************************** Console.WriteLine("=============== Start training ==============="); var watch = Stopwatch.StartNew(); _model = pipeline.Train <BankStatementLineItem, PredictedLabel>(); watch.Stop(); Console.WriteLine($"=============== End training ==============="); Console.WriteLine($"training took {watch.ElapsedMilliseconds} milliseconds"); Console.WriteLine("The model is saved to {0}", PredictionModelWrapper.Model1Path); //******************************************************************** var converter = new OnnxConverter { Onnx = PredictionModelWrapper.Model1Path, Json = PredictionModelWrapper.Model1Path.Replace(".onnx", ".json"), Domain = "onnx" }; converter.Convert(_model); if (writeToDisk) { await _model.WriteAsync(PredictionModelWrapper.Model1Path); // Strip the version. var fileText = File.ReadAllText(converter.Json); fileText = Regex.Replace(fileText, "\"producerVersion\": \"([^\"]+)\"", "\"producerVersion\": \"##VERSION##\""); File.WriteAllText(converter.Json, fileText); } return(PredictionModelWrapper.Model1Path); }
public static void Main(string[] args) { Utils.PrintConsoleMessage("Starting Baseball HOF Prediction", true); var currentDirectory = Directory.GetCurrentDirectory(); var onnxPath = Path.Combine(currentDirectory, "baseballhof-model.onnx"); var onnxAsJsonPath = Path.Combine(currentDirectory, "baseballhof-model.json"); // MCC Evaluation metric double mcc = 0.0; double mccNumerator = 0.0, mccDenominator = 0.0; Console.WriteLine("Run Naive baseball data set? (Type Y or N and then press [Enter])"); string consoleResponse = Console.ReadLine(); // Training & Validation/Dev text CSV files var trainingDataPath = "HOFTrainingNaive.txt"; var validationDataPath = "HOFValidationNaive.txt"; if (consoleResponse.ToUpper() != "Y") { trainingDataPath = "HOFTraining.txt"; validationDataPath = "HOFValidation.txt"; } // 1) Create a new learning pipeline var pipeline = new LearningPipeline(); // 2) Add a Text Loader var trainingLoader = new Microsoft.ML.Legacy.Data.TextLoader(trainingDataPath).CreateFrom <BaseballData>(allowQuotedStrings: false, separator: ','); pipeline.Add(trainingLoader); // 3) Create Features pipeline.Add(new ColumnConcatenator("Features", "YearsPlayed", "AB", "R", "H", "Doubles", "Triples", "HR", "RBI", "SB", "AllStarAppearances", "MVPs", "TripleCrowns", "GoldGloves", "MajorLeaguePlayerOfTheYearAwards", "TB")); // pipeline.Add(new ColumnConcatenator("Features", "YearsPlayed")); // 4) Create new binary classifier (predict yes/no into Baseball HOF) var classifier = new FastTreeBinaryClassifier() { NumLeaves = 10, NumTrees = 60, MinDocumentsInLeafs = 2, BaggingSize = 5, AllowEmptyTrees = true, Caching = CachingOptions.Memory, OptimizationAlgorithm = BoostedTreeArgsOptimizationAlgorithmType.GradientDescent }; // var classifier = new LinearSvmBinaryClassifier(); // var classifier = new FastForestBinaryClassifier(); // var classifier = new GeneralizedAdditiveModelBinaryClassifier(); //var classifier = new LightGbmBinaryClassifier //{ // NumLeaves = 5, // NumBoostRound = 5, // MinDataPerLeaf = 2 //}; // Add the classifier to the pipeline pipeline.Add(classifier); // 5) Train Model var model = pipeline.Train <BaseballData, BaseballDataPrediction>(); // 6) Sample Predictions // Bad Player with poor historical numbers var samplePredictionBadPlayer = new BaseballData { AB = 3000, AllStarAppearances = 0, FullPlayerName = "Bad Player", R = 90, H = 300, Doubles = 30, Triples = 30, HR = 30, RBI = 60, SB = 15, BattingAverage = 0.1f, SluggingPct = 0.25f, PlayerID = 10101, Label = false, MVPs = 0, GoldGloves = 0, MajorLeaguePlayerOfTheYearAwards = 0, TripleCrowns = 0, YearsPlayed = 3, TB = 570 }; var result = model.Predict(samplePredictionBadPlayer); Console.WriteLine("Bad Baseball Player Prediction"); Console.WriteLine("******************************"); Console.WriteLine("HOF Prediction: " + result.PredictedLabel.ToString() + " | " + "Probability: " + Math.Round(result.ProbabilityLabel, 7)); Console.WriteLine(); // Great Player with great historical numbers worthy of HOF var samplePredictionGreatPlayer = new BaseballData { AB = 10000, AllStarAppearances = 12, FullPlayerName = "Great Player", R = 1100, H = 3200, Doubles = 450, Triples = 150, HR = 600, RBI = 1200, SB = 400, BattingAverage = 0.32f, SluggingPct = 0.55f, PlayerID = 20202, Label = true, MVPs = 3, GoldGloves = 8, MajorLeaguePlayerOfTheYearAwards = 4, TripleCrowns = 2, YearsPlayed = 22, TB = 6700 }; var greatPlayerPrediction = model.Predict(samplePredictionGreatPlayer); Console.WriteLine("Great Baseball Player Prediction"); Console.WriteLine("******************************"); Console.WriteLine("HOF Prediction: " + greatPlayerPrediction.PredictedLabel.ToString() + " | " + "Probability: " + greatPlayerPrediction.ProbabilityLabel); Console.WriteLine(); Console.WriteLine(); // 7) Load Evaluation Data var testData = new Microsoft.ML.Legacy.Data.TextLoader(validationDataPath).CreateFrom <BaseballData>(allowQuotedStrings: false, separator: ','); // 8) Evaluate trained model with test data var evaluator = new BinaryClassificationEvaluator() { ProbabilityColumn = "Probability" }; var metrics = evaluator.Evaluate(model, testData); // build a list of False Positives - Players not in the HOF, predicted by classifier to be in HOF var falsePostivePlayers = new List <Tuple <BaseballData, BaseballDataPrediction> >(); // build a list of False Negatives - Players IN THE HOF, predicted by classifier not to be in HOF var falseNegativePlayers = new List <Tuple <BaseballData, BaseballDataPrediction> >(); // build a list of True Positives - Players IN THE HOF, predicted by classifier to be in HOF var truePositivePlayers = new List <Tuple <BaseballData, BaseballDataPrediction> >(); // build a list of True Negataives - Players not in the HOF, predicted by classifier not to be in HOF var trueNegativePlayers = new List <Tuple <BaseballData, BaseballDataPrediction> >(); using (var environment = new LocalEnvironment()) { // note: custom schema not needed anymore // var customSchema = "col=Label:BL:0 col=FullPlayerName:TX:1 col=YearsPlayed:R4:2 col=AB:R4:3 col=R:R4:4 col=H:R4:5 col=Doubles:R4:6 col=Triples:R4:7 col=HR:R4:8 col=RBI:R4:9 col=SB:R4:10 col=BattingAverage:R4:11 col=SluggingPct:R4:12 col=AllStarAppearances:R4:13 col=MVPs:R4:14 col=TripleCrowns:R4:15 col=GoldGloves:R4:16 col=MajorLeaguePlayerOfTheYearAwards:R4:17 col=TB:R4:18 col=LastYearPlayed:R4:19 col=PlayerID:R4:20 Separator=,"; var loader = new Microsoft.ML.Legacy.Data.TextLoader(validationDataPath).CreateFrom <BaseballData>(allowQuotedStrings: false, separator: ','); Experiment experiment = environment.CreateExperiment(); ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as ILearningPipelineDataStep; experiment.Compile(); loader.SetInput(environment, experiment); experiment.Run(); IDataView data = experiment.GetOutput(output.Data); using (var cursor = data.GetRowCursor(col => true)) { cursor.Schema.TryGetColumnIndex("Label", out int labelCol); cursor.Schema.TryGetColumnIndex("FullPlayerName", out int fullPlayerNameCol); cursor.Schema.TryGetColumnIndex("YearsPlayed", out int yearsPlayedCol); cursor.Schema.TryGetColumnIndex("AB", out int abCol); cursor.Schema.TryGetColumnIndex("R", out int rCol); cursor.Schema.TryGetColumnIndex("H", out int hCol); cursor.Schema.TryGetColumnIndex("Doubles", out int doublesCol); cursor.Schema.TryGetColumnIndex("Triples", out int triplesCol); cursor.Schema.TryGetColumnIndex("HR", out int hrCol); cursor.Schema.TryGetColumnIndex("RBI", out int rbiCol); cursor.Schema.TryGetColumnIndex("SB", out int sbCol); cursor.Schema.TryGetColumnIndex("AllStarAppearances", out int allStarAppearancesCol); cursor.Schema.TryGetColumnIndex("MVPs", out int mvpsCol); cursor.Schema.TryGetColumnIndex("TripleCrowns", out int tripleCrownsCol); cursor.Schema.TryGetColumnIndex("GoldGloves", out int goldGlovesCol); cursor.Schema.TryGetColumnIndex("MajorLeaguePlayerOfTheYearAwards", out int majorLeaguePlayerOfTheYearAwardsCol); cursor.Schema.TryGetColumnIndex("TB", out int tbCol); cursor.Schema.TryGetColumnIndex("PlayerID", out int idCol); while (cursor.MoveNext()) { // Label var labelGetter = cursor.GetGetter <bool>(labelCol); var label = default(bool); labelGetter(ref label); // Full Player Name var fullPlayerNameGetter = cursor.GetGetter <ReadOnlyMemory <char> >(fullPlayerNameCol); var fullPlayerName = default(ReadOnlyMemory <char>); fullPlayerNameGetter(ref fullPlayerName); // Years Played var yearsPlayedGetter = cursor.GetGetter <float>(yearsPlayedCol); var yearsPlayed = 0f; yearsPlayedGetter(ref yearsPlayed); // AB var abGetter = cursor.GetGetter <float>(abCol); var ab = 0f; abGetter(ref ab); // R var rGetter = cursor.GetGetter <float>(rCol); var r = 0f; rGetter(ref r); // H var hGetter = cursor.GetGetter <float>(hCol); var h = 0f; hGetter(ref h); // Doubles var doublesGetter = cursor.GetGetter <float>(doublesCol); float doubles = 0f; doublesGetter(ref doubles); // Triples var triplesGetter = cursor.GetGetter <float>(triplesCol); float triples = 0f; triplesGetter(ref triples); // HR var hrGetter = cursor.GetGetter <float>(hrCol); float hr = 0f; hrGetter(ref hr); // RBI var rbiGetter = cursor.GetGetter <float>(rbiCol); float rbi = 0f; rbiGetter(ref rbi); // SB var sbGetter = cursor.GetGetter <float>(sbCol); float sb = 0f; sbGetter(ref sb); // AllStarAppearances var allStarAppearancesGetter = cursor.GetGetter <float>(allStarAppearancesCol); float allStarAppearances = 0f; allStarAppearancesGetter(ref allStarAppearances); // MVPs var mvpsGetter = cursor.GetGetter <float>(mvpsCol); float mvps = 0f; mvpsGetter(ref mvps); // Tiple Crowns var tripleCrownsGetter = cursor.GetGetter <float>(tripleCrownsCol); float tripleCrowns = 0f; tripleCrownsGetter(ref tripleCrowns); // Gold Gloves var goldGlovesGetter = cursor.GetGetter <float>(goldGlovesCol); float goldGloves = 0f; goldGlovesGetter(ref goldGloves); // MajorLeaguePlayerOfTheYearAwards var majorLeaguePlayerOfTheYearAwardsGetter = cursor.GetGetter <float>(majorLeaguePlayerOfTheYearAwardsCol); float majorLeaguePlayerOfTheYearAwards = 0f; majorLeaguePlayerOfTheYearAwardsGetter(ref majorLeaguePlayerOfTheYearAwards); // TB var tbGetter = cursor.GetGetter <float>(tbCol); float tb = 0f; tbGetter(ref tb); // PlayerID column var idGetter = cursor.GetGetter <float>(idCol); float id = 0f; idGetter(ref id); var baseBallData = new BaseballData() { AB = ab, AllStarAppearances = allStarAppearances, Doubles = doubles, FullPlayerName = fullPlayerName.ToString(), H = h, HR = hr, GoldGloves = goldGloves, PlayerID = id, Label = label, //label.IsTrue ? true : false, MajorLeaguePlayerOfTheYearAwards = majorLeaguePlayerOfTheYearAwards, MVPs = mvps, R = r, RBI = rbi, SB = sb, TB = tb, TripleCrowns = tripleCrowns, Triples = triples, YearsPlayed = yearsPlayed }; var prediction = model.Predict(baseBallData); // True Positives if ((prediction.PredictedLabel == true) && (baseBallData.Label == true)) { truePositivePlayers.Add(new Tuple <BaseballData, BaseballDataPrediction>(baseBallData, prediction)); } // True Negatives if ((prediction.PredictedLabel == false) && (baseBallData.Label == false)) { trueNegativePlayers.Add(new Tuple <BaseballData, BaseballDataPrediction>(baseBallData, prediction)); } // False Positive Prediction if ((prediction.PredictedLabel == true) && (baseBallData.Label == false)) { falsePostivePlayers.Add(new Tuple <BaseballData, BaseballDataPrediction>(baseBallData, prediction)); } else // False Negative Prediction if ((prediction.PredictedLabel == false) && (baseBallData.Label == true)) { falseNegativePlayers.Add(new Tuple <BaseballData, BaseballDataPrediction>(baseBallData, prediction)); } } } } // 9) Print out Metrics (rounded to 4 decimals) mccNumerator = truePositivePlayers.Count * trueNegativePlayers.Count - falsePostivePlayers.Count * falseNegativePlayers.Count; mccDenominator = Math.Sqrt( 1.0 * (truePositivePlayers.Count + falsePostivePlayers.Count) * (truePositivePlayers.Count + falseNegativePlayers.Count) * (trueNegativePlayers.Count + falsePostivePlayers.Count) * (trueNegativePlayers.Count + falseNegativePlayers.Count) ); mcc = mccNumerator / mccDenominator; //Console.WriteLine(mccNumerator); //Console.WriteLine(mccDenominator); Console.WriteLine("******************"); Console.WriteLine("Evaluation Metrics"); Console.WriteLine("******************"); Console.WriteLine("AUC Score: " + Math.Round(metrics.Auc, 4).ToString()); Console.WriteLine("Precision: " + Math.Round(metrics.PositivePrecision, 4).ToString()); Console.WriteLine("Recall: " + Math.Round(metrics.PositiveRecall, 4).ToString()); Console.WriteLine("Accuracy: " + Math.Round(metrics.Accuracy, 4).ToString()); Console.WriteLine("MCC: " + Math.Round(mcc, 4).ToString()); Console.WriteLine("******************"); Console.WriteLine(); Console.WriteLine("******************"); Console.WriteLine("True Positives"); Console.WriteLine("******************"); for (int i = 0; i != truePositivePlayers.Count; i++) { var player = truePositivePlayers[i].Item1; var playerPrediction = truePositivePlayers[i].Item2; Console.WriteLine(player.ToString() + "Prob: " + playerPrediction.ProbabilityLabel.ToString()); } Console.WriteLine(); Console.WriteLine("******************"); Console.WriteLine("False Positives"); Console.WriteLine("******************"); for (int i = 0; i != falsePostivePlayers.Count; i++) { var player = falsePostivePlayers[i].Item1; var playerPrediction = falsePostivePlayers[i].Item2; Console.WriteLine(player.ToString() + "Prob: " + playerPrediction.ProbabilityLabel.ToString()); } Console.WriteLine(); Console.WriteLine("******************"); Console.WriteLine("False Negatives"); Console.WriteLine("******************"); for (int i = 0; i != falseNegativePlayers.Count; i++) { var player = falseNegativePlayers[i].Item1; var playerPrediction = falseNegativePlayers[i].Item2; Console.WriteLine(player.ToString() + "Prob: " + playerPrediction.ProbabilityLabel.ToString()); } //10) Persist trained model var modelPath = Path.Combine(currentDirectory, "baseballhof-model.mlnet"); model.WriteAsync(modelPath).Wait(); //11) Convert to ONNX & persist // will only work for FastTree, LightGBM, Logistic Regression // var onnxPath = Path.Combine(currentDirectory, "baseballhof-model.onnx"); //var onnxAsJsonPath = Path.Combine(currentDirectory, "baseballhof-model.json"); OnnxConverter converter = new OnnxConverter() { InputsToDrop = new[] { "Label" }, OutputsToDrop = new[] { "Label", "Features" }, Onnx = onnxPath, Json = onnxAsJsonPath, Domain = "com.baseballsample" }; // converter.Convert(model); }