/// <summary>
/// Evaluates the transfer-learning image classifier on the flower_photos
/// dataset and stores the reported accuracy in the <c>accuracy</c> field.
/// </summary>
public override void Test()
{
    const string taskDir = "image_classification_v1";

    // Build paths with Path.Join rather than a hard-coded "\" so the sample
    // also resolves the dataset and model on Linux/macOS.
    var wizard = new ModelWizard();
    var task = wizard.AddImageClassificationTask<TransferLearning>(new TaskOptions
    {
        DataDir = Path.Join(taskDir, "flower_photos"),
        ModelPath = Path.Join(taskDir, "saved_model.pb")
    });

    var result = task.Test();
    accuracy = result.Accuracy;
}
/// <summary>
/// Prediction.
/// Label mapping (from output_labels.txt):
/// 0 - daisy
/// 1 - dandelion
/// 2 - roses
/// 3 - sunflowers
/// 4 - tulips
/// Classifies a known daisy image with the saved model and asserts (debug builds
/// only, via <see cref="Debug.Assert(bool, string)"/>) that the predicted label is "daisy".
/// </summary>
public override void Predict()
{
    const string taskDir = "image_classification_v1";

    // Path.Join instead of a hard-coded "\" so the model path is portable.
    var wizard = new ModelWizard();
    var task = wizard.AddImageClassificationTask<TransferLearning>(new TaskOptions
    {
        ModelPath = Path.Join(taskDir, "saved_model.pb")
    });

    // Predict a single image from the daisy class.
    var imgPath = Path.Join(taskDir, "flower_photos", "daisy", "5547758_eea9edfd54_n.jpg");
    var input = ImageUtil.ReadImageFromFile(imgPath);
    var result = task.Predict(input);

    // Include the actual label so a failed assertion is diagnosable.
    Debug.Assert(result.Label == "daisy", $"Expected 'daisy' but predicted '{result.Label}'.");
}
/// <summary>
/// Downloads and extracts the flower_photos dataset into
/// <c>image_classification_v1</c>, then trains a transfer-learning image
/// classifier on it with default <c>TrainingOptions</c>.
/// </summary>
public override void Train()
{
    // Get a set of images to teach the network about the new classes.
    const string fileName = "flower_photos.tgz";
    const string dataDir = "image_classification_v1";

    // Use HTTPS so the downloaded archive cannot be tampered with in transit.
    string url = $"https://download.tensorflow.org/example_images/{fileName}";
    Web.Download(url, dataDir, fileName);
    Compress.ExtractTGZ(Path.Join(dataDir, fileName), dataDir);

    // Train via the wizard; reuse dataDir with Path.Join instead of a
    // Windows-only "\" literal so the data directory resolves on every OS.
    var wizard = new ModelWizard();
    var task = wizard.AddImageClassificationTask<TransferLearning>(new TaskOptions
    {
        DataDir = Path.Join(dataDir, "flower_photos"),
    });
    task.Train(new TrainingOptions());
}