/// <summary>
/// Builds the image-classification pipeline, fits it on the data read from
/// <c>dataLocation</c>, prints the predictions and log-loss metrics obtained on
/// that same training data, and writes the fitted model to
/// <c>outputModelLocation</c>.
/// </summary>
public void BuildAndTrain()
{
    var featurizerModelLocation = inputModelLocation;

    ConsoleWriteHeader("Read model");
    Console.WriteLine($"Model location: {featurizerModelLocation}");
    Console.WriteLine($"Images folder: {imagesFolder}");
    Console.WriteLine($"Training file: {dataLocation}");
    Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight}), image mean: {ImageNetSettings.mean}");

    // The training file supplies two text columns: the image path (column 0)
    // and its label (column 1).
    var loader = new TextLoader(env, new TextLoader.Arguments
    {
        Column = new[]
        {
            new TextLoader.Column("ImagePath", DataKind.Text, 0),
            new TextLoader.Column("Label", DataKind.Text, 1)
        }
    });

    // Stages are chained in the order they are appended: label -> key, load image,
    // resize, extract pixels, run the TensorFlow model, train the SDCA multiclass
    // trainer on the "softmax2_pre_activation" output, then map the predicted key
    // back to a value.
    // NOTE(review): the resize stage receives (imageHeight, imageWidth) in that
    // order; confirm this matches the ImageResizingEstimator parameter order.
    // It only matters when the two settings differ.
    var pipeline = new ValueToKeyMappingEstimator(env, "Label", "LabelTokey")
        .Append(new ImageLoadingEstimator(env, imagesFolder, ("ImagePath", "ImageReal")))
        .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", ImageNetSettings.imageHeight, ImageNetSettings.imageWidth))
        .Append(new ImagePixelExtractingEstimator(env, new[]
        {
            new ImagePixelExtractorTransform.ColumnInfo("ImageReal", "input", interleave: ImageNetSettings.channelsLast, offset: ImageNetSettings.mean)
        }))
        .Append(new TensorFlowEstimator(env, featurizerModelLocation, new[] { "input" }, new[] { "softmax2_pre_activation" }))
        .Append(new SdcaMultiClassTrainer(env, "softmax2_pre_activation", "LabelTokey"))
        .Append(new KeyToValueEstimator(env, ("PredictedLabel", "PredictedLabelValue")));

    // Train the pipeline
    ConsoleWriteHeader("Training classification model");
    var data = loader.Read(new MultiFileSource(dataLocation));
    var model = pipeline.Fit(data);

    // Process the training data through the model
    // This is an optional step, but it's useful for debugging issues
    var trainData = model.Transform(data);

    // Not read afterwards; kept so the output column names can be inspected
    // while debugging.
    var loadedModelOutputColumnNames = trainData.Schema.GetColumnNames();

    var trainData2 = trainData.AsEnumerable<ImageNetPipeline>(env, false, true).ToList();
    foreach (var pr in trainData2)
    {
        // The highest score is reported alongside the predicted label.
        ConsoleWriteImagePrediction(pr.ImagePath, pr.PredictedLabelValue, pr.Score.Max());
    }

    // Get some performance metric on the model using training data
    var sdcaContext = new MulticlassClassificationContext(env);
    ConsoleWriteHeader("Classification metrics");
    var metrics = sdcaContext.Evaluate(trainData, label: "LabelTokey", predictedLabel: "PredictedLabel");
    Console.WriteLine($"LogLoss is: {metrics.LogLoss}");
    Console.WriteLine($"PerClassLogLoss is: {String.Join(" , ", metrics.PerClassLogLoss.Select(c => c.ToString()))}");

    // Save the model to assets/outputs
    ConsoleWriteHeader("Save model to local file");
    ModelHelpers.DeleteAssets(outputModelLocation);
    using (var f = new FileStream(outputModelLocation, FileMode.Create))
    {
        model.SaveTo(env, f);
    }

    Console.WriteLine($"Model saved: {outputModelLocation}");
}
/// <summary>
/// Exercises <c>KeyToValueEstimator</c> on columns read from "iris.txt": checks
/// the estimator against valid and invalid inputs, saves the transformed data
/// to "featurized.tsv" under the "KeyToValue" output folder, and compares the
/// saved file via <c>CheckEquality</c>.
/// </summary>
public void KeyToValueWorkout()
{
    string irisPath = GetDataPath("iris.txt");

    // Column layout: one scalar text column, one text vector spanning
    // columns 1..4, and one U4 column declared with a key range.
    var scalarStringColumn = new TextLoader.Column("ScalarString", DataKind.TX, 1);
    var vectorStringColumn = new TextLoader.Column("VectorString", DataKind.TX, new[] { new TextLoader.Range(1, 4) });
    var bareKeyColumn = new TextLoader.Column
    {
        Name = "BareKey",
        Source = new[] { new TextLoader.Range(0) },
        Type = DataKind.U4,
        KeyRange = new KeyRange(0, 5),
    };

    var loaderArgs = new TextLoader.Arguments
    {
        Column = new[] { scalarStringColumn, vectorStringColumn, bareKeyColumn }
    };
    var loader = new TextLoader(Env, loaderArgs);
    var rawData = loader.Read(irisPath);

    // Produce key-typed columns "A" and "B" from the two string columns.
    var toKey = new ValueToKeyMappingEstimator(Env, new[]
    {
        new TermTransform.ColumnInfo("ScalarString", "A"),
        new TermTransform.ColumnInfo("VectorString", "B")
    });
    var keyedData = toKey.Fit(rawData).Transform(rawData);

    // Inputs the estimator is expected to reject: "A" and "B" are replaced by
    // copies of other columns.
    var invalidA = new CopyColumnsTransform(Env, ("BareKey", "A")).Transform(keyedData);
    var invalidB = new CopyColumnsTransform(Env, ("VectorString", "B")).Transform(keyedData);

    var keyToValue = new KeyToValueEstimator(Env, ("A", "A_back"), ("B", "B_back"));
    TestEstimatorCore(keyToValue, keyedData, invalidInput: invalidA);
    TestEstimatorCore(keyToValue, keyedData, invalidInput: invalidB);

    var featurizedPath = GetOutputPath("KeyToValue", "featurized.tsv");
    using (var channel = Env.Start("save"))
    {
        var textSaver = new TextSaver(Env, new TextSaver.Arguments { Silent = true });
        IDataView roundTripped = keyToValue.Fit(keyedData).Transform(keyedData);
        using (var output = File.Create(featurizedPath))
        {
            DataSaverUtils.SaveDataView(channel, textSaver, roundTripped, output, keepHidden: true);
        }
    }

    CheckEquality("KeyToValue", "featurized.tsv");
    Done();
}