// Fits an image-processing estimator chain, saves the fitted model to a temp file,
// loads it back, and checks that the first transformer of the reloaded chain exposes
// the same Columns as the first transformer of the original chain.
public void TestEstimatorSaveLoad()
{
    using (var env = new TlcEnvironment())
    {
        var dataFile = GetDataPath("images/images.tsv");
        // Image paths are resolved relative to the folder that holds the .tsv file.
        var imageFolder = Path.GetDirectoryName(dataFile);
        var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));

        var pipe = new ImageLoaderEstimator(env, imageFolder, ("ImagePath", "ImageReal"))
            .Append(new ImageResizerEstimator(env, "ImageReal", "ImageReal", 100, 100))
            .Append(new ImagePixelExtractorEstimator(env, "ImageReal", "ImagePixels"))
            .Append(new ImageGrayscaleEstimator(env, ("ImageReal", "ImageGray")));

        // The returned schema is discarded; NOTE(review): presumably this call is here only
        // to exercise schema propagation through the chain - confirm.
        pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema));
        var model = pipe.Fit(data);

        using (var file = env.CreateTempFile())
        {
            using (var fs = file.CreateWriteStream())
                model.SaveTo(env, fs);

            // Dispose the read stream before the temp file itself is disposed, mirroring
            // how the write stream is handled above.
            using (var fs = file.OpenReadStream())
            {
                var model2 = TransformerChain.LoadFrom(env, fs);

                var newCols = ((ImageLoaderTransform)model2.First()).Columns;
                var oldCols = ((ImageLoaderTransform)model.First()).Columns;

                // Zip stops at the shorter sequence, so compare the lengths explicitly;
                // otherwise a missing or extra column would go unnoticed.
                Assert.Equal(oldCols.Count(), newCols.Count());
                Assert.True(newCols
                    .Zip(oldCols, (x, y) => x == y)
                    .All(x => x));
            }
        }
    }
    Done();
}
// Trains a linear classifier on the sentiment data, saves the model to a temp file with
// TrainUtils.SaveModel, reloads it as a prediction engine, and checks the predictions
// for the first five examples of the sentiment test data.
public void TrainSaveModelAndPredict()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var trans = TextTransform.Create(env, MakeSentimentTextTransformArgs(), loader);

        // Train
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1
        });

        var cached = new CacheDataView(env, trans, prefetch: null);
        var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        PredictionEngine<SentimentData, SentimentPrediction> model;
        using (var file = env.CreateTempFile())
        {
            // Save model. The role mapping is built over the uncached transform output.
            var scoreRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
            using (var ch = env.Start("saving"))
                TrainUtils.SaveModel(env, ch, file, predictor, scoreRoles);

            // Load model.
            using (var fs = file.OpenReadStream())
                model = env.CreatePredictionEngine<SentimentData, SentimentPrediction>(fs);
        }

        // Take a couple examples out of the test data and run predictions on top.
        // Reuses the test path resolved at the top of the method instead of resolving it again.
        var testLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(testDataPath));
        var testData = testLoader.AsEnumerable<SentimentData>(env, false);
        foreach (var input in testData.Take(5))
        {
            var prediction = model.Predict(input);

            // Verify that predictions match and scores are separated from zero:
            // positive examples must score above 1, negative examples below -1.
            Assert.Equal(input.Sentiment, prediction.Sentiment);
            Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
        }
    }
}
// Estimator-based variant of the train/save/predict scenario: fits a TextTransform +
// LinearClassificationTrainer chain on the sentiment data, saves it to a temp file,
// reloads it, and checks the predictions for the first five examples of the test data.
public void New_TrainSaveModelAndPredict()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        var reader = new TextLoader(env, MakeSentimentTextLoaderArgs());
        var data = reader.Read(new MultiFileSource(dataPath));

        // Pipeline.
        var pipeline = new TextTransform(env, "SentimentText", "Features")
            .Append(new LinearClassificationTrainer(
                env,
                new LinearClassificationTrainer.Arguments { NumThreads = 1 },
                "Features",
                "Label"));

        // Train.
        var model = pipeline.Fit(data);

        ITransformer loadedModel;
        using (var file = env.CreateTempFile())
        {
            // Save model.
            using (var fs = file.CreateWriteStream())
                model.SaveTo(env, fs);

            // Load model. The read stream is disposed before the temp file is, mirroring
            // how the write stream is handled above.
            using (var fs = file.OpenReadStream())
                loadedModel = TransformerChain.LoadFrom(env, fs);
        }

        // Create prediction engine and test predictions.
        var engine = loadedModel.MakePredictionFunction<SentimentData, SentimentPrediction>(env);

        // Take a couple examples out of the test data and run predictions on top.
        // Reuses the test path resolved at the top of the method instead of resolving it again.
        var testData = reader.Read(new MultiFileSource(testDataPath))
            .AsEnumerable<SentimentData>(env, false);
        foreach (var input in testData.Take(5))
        {
            var prediction = engine.Predict(input);

            // Verify that predictions match and scores are separated from zero:
            // positive examples must score above 1, negative examples below -1.
            Assert.Equal(input.Sentiment, prediction.Sentiment);
            Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
        }
    }
}