public void TrainAndPredictIrisModelUsingDirectInstantiationTest() { string dataPath = GetDataPath("iris.txt"); string testDataPath = dataPath; using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) { // Pipeline var loader = TextLoader.ReadFile(env, new TextLoader.Arguments() { HasHeader = false, Column = new[] { new TextLoader.Column("Label", DataKind.R4, 0), new TextLoader.Column("SepalLength", DataKind.R4, 1), new TextLoader.Column("SepalWidth", DataKind.R4, 2), new TextLoader.Column("PetalLength", DataKind.R4, 3), new TextLoader.Column("PetalWidth", DataKind.R4, 4) } }, new MultiFileSource(dataPath)); IDataView pipeline = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(loader); // Normalizer is not automatically added though the trainer has 'NormalizeFeatures' On/Auto pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features"); // Train var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 }); // Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto var cached = new CacheDataView(env, pipeline, prefetch: null); var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); var pred = trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data IDataScorerTransform testDataScorer = GetScorer(env, pipeline, pred, testDataPath); var metrics = Evaluate(env, testDataScorer); CompareMatrics(metrics); // Create prediction engine and test predictions var model = env.CreatePredictionEngine <IrisData, IrisPrediction>(testDataScorer); ComparePredictions(model); // Get feature importance i.e. weight vector var summary = ((MulticlassLogisticRegressionPredictor)pred).GetSummaryInKeyValuePairs(trainRoles.Schema); Assert.Equal(7.757864, Convert.ToDouble(summary[0].Value), 5); } }
public void TensorFlowTransformMNISTConvTest() { var model_location = "mnist_model/frozen_saved_model.pb"; using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) { var dataPath = GetDataPath("Train-Tiny-28x28.txt"); var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); // Pipeline var loader = TextLoader.ReadFile(env, new TextLoader.Arguments() { Separator = "tab", HasHeader = true, Column = new[] { new TextLoader.Column("Label", DataKind.Num, 0), new TextLoader.Column("Placeholder", DataKind.Num, new [] { new TextLoader.Range(1, 784) }) } }, new MultiFileSource(dataPath)); IDataView trans = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments() { Column = new[] { new CopyColumnsTransform.Column() { Name = "reshape_input", Source = "Placeholder" } } }, loader); trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); trans = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(trans); var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments()); var cached = new CacheDataView(env, trans, prefetch: null); var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); var pred = trainer.Train(trainRoles); // Get scorer and evaluate the predictions from test data IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); var metrics = Evaluate(env, testDataScorer); Assert.Equal(0.99, metrics.AccuracyMicro, 2); Assert.Equal(1.0, metrics.AccuracyMacro, 2); // Create prediction engine and test predictions var model = env.CreatePredictionEngine <MNISTData, MNISTPrediction>(testDataScorer); var sample1 = new MNISTData() { Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }; var prediction = model.Predict(sample1); float max = -1; int maxIndex = -1; for (int i = 0; i < prediction.PredictedLabels.Length; i++) { if (prediction.PredictedLabels[i] > max) { max = prediction.PredictedLabels[i]; maxIndex = i; } } Assert.Equal(5, maxIndex); } }