/// <summary>
/// Builds an estimator that retrains an existing DNN model on new data.
/// </summary>
/// <param name="catalog">The catalog this extension method is invoked on; it is used to obtain the host environment.</param>
/// <param name="outputColumnNames">The names of the requested model outputs.</param>
/// <param name="inputColumnNames">The names of the model inputs.</param>
/// <param name="labelColumnName">Name of the label column.</param>
/// <param name="dnnLabel">Name of the node in the DNN graph that is used as label during training.
/// The value of <paramref name="labelColumnName"/> from <see cref="IDataView"/> is fed to this node.</param>
/// <param name="optimizationOperation">The name of the optimization operation in the DNN graph.</param>
/// <param name="modelPath">Path to the model file to retrain.</param>
/// <param name="epoch">Number of training iterations.</param>
/// <param name="batchSize">Number of samples to use for mini-batch training.</param>
/// <param name="lossOperation">The name of the operation in the DNN graph to compute training loss (Optional).</param>
/// <param name="metricOperation">The name of the operation in the DNN graph to compute performance metric during training (Optional).</param>
/// <param name="learningRateOperation">The name of the operation in the DNN graph which sets optimizer learning rate (Optional).</param>
/// <param name="learningRate">Learning rate to use during optimization (Optional).</param>
/// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
/// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.</param>
/// <param name="dnnFramework">The DNN framework to use. This method does not currently read the value.</param>
/// <returns>A <see cref="DnnEstimator"/> configured for retraining the model loaded from <paramref name="modelPath"/>.</returns>
/// <remarks>
/// The support for retraining is under preview.
/// </remarks>
public static DnnEstimator RetrainDnnModel(
    this ModelOperationsCatalog catalog,
    string[] outputColumnNames,
    string[] inputColumnNames,
    string labelColumnName,
    string dnnLabel,
    string optimizationOperation,
    string modelPath,
    int epoch = 10,
    int batchSize = 20,
    string lossOperation = null,
    string metricOperation = null,
    string learningRateOperation = null,
    float learningRate = 0.01f,
    bool addBatchDimensionInput = false,
    DnnFramework dnnFramework = DnnFramework.Tensorflow)
{
    // Populate the retraining configuration one member at a time, in the same
    // order the members were originally initialized.
    var retrainOptions = new Options();
    retrainOptions.ModelLocation = modelPath;
    retrainOptions.InputColumns = inputColumnNames;
    retrainOptions.OutputColumns = outputColumnNames;
    retrainOptions.LabelColumn = labelColumnName;
    retrainOptions.TensorFlowLabel = dnnLabel;
    retrainOptions.OptimizationOperation = optimizationOperation;
    retrainOptions.LossOperation = lossOperation;
    retrainOptions.MetricOperation = metricOperation;
    retrainOptions.Epoch = epoch;
    retrainOptions.LearningRateOperation = learningRateOperation;
    retrainOptions.LearningRate = learningRate;
    retrainOptions.BatchSize = batchSize;
    retrainOptions.AddBatchDimensionInputs = addBatchDimensionInput;
    retrainOptions.ReTrain = true;

    var host = CatalogUtils.GetEnvironment(catalog);

    // The model graph is loaded from the caller-supplied path before the estimator is built.
    var loadedModel = DnnUtils.LoadDnnModel(host, modelPath, true);
    return new DnnEstimator(host, retrainOptions, loadedModel);
}
/// <summary>
/// Performs image classification using transfer learning.
/// Usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
/// <format type="text/markdown">
/// <![CDATA[
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
/// ]]>
/// </format>
/// </summary>
/// <param name="catalog">The catalog this extension method is invoked on; it is used to obtain the host environment.</param>
/// <param name="featuresColumnName">The name of the input features column.</param>
/// <param name="labelColumnName">The name of the labels column.</param>
/// <param name="scoreColumnName">The name of the output score column.</param>
/// <param name="predictedLabelColumnName">The name of the output predicted label column.</param>
/// <param name="arch">The architecture of the image recognition DNN model.</param>
/// <param name="epoch">Number of training iterations. Each iteration/epoch refers to one pass over the dataset.</param>
/// <param name="batchSize">The batch size for training.</param>
/// <param name="learningRate">The learning rate for training.</param>
/// <param name="disableEarlyStopping">Whether to disable use of early stopping technique. Training will go on for the full epoch count.</param>
/// <param name="earlyStopping">Early stopping technique parameters to be used to terminate training when training metric stops improving.
/// When <see langword="null"/> and early stopping is not disabled, a default <c>EarlyStopping</c> instance is used.</param>
/// <param name="metricsCallback">Callback for reporting model statistics during training phase.</param>
/// <param name="statisticFrequency">Indicates the frequency of epochs at which to report model statistics during training phase.</param>
/// <param name="framework">Indicates the choice of DNN training framework. Currently only tensorflow is supported.</param>
/// <param name="modelSavePath">Optional path where a copy of the new graph should be saved. The graph will be saved as part of the model.</param>
/// <param name="finalModelPrefix">The name of the prefix for the final model and checkpoint files.</param>
/// <param name="validationSet">Validation set.</param>
/// <param name="testOnTrainSet">Indicates to evaluate the model on train set after every epoch.</param>
/// <param name="reuseTrainSetBottleneckCachedValues">Indicates to not re-compute cached trainset bottleneck values if already available in the bin folder.</param>
/// <param name="reuseValidationSetBottleneckCachedValues">Indicates to not re-compute cached validationset bottleneck values if already available in the bin folder.</param>
/// <param name="trainSetBottleneckCachedValuesFilePath">Indicates the file path to store trainset bottleneck values for caching.</param>
/// <param name="validationSetBottleneckCachedValuesFilePath">Indicates the file path to store validationset bottleneck values for caching.</param>
/// <returns>An <see cref="ImageClassificationEstimator"/> configured with the given options and the loaded model graph.</returns>
/// <remarks>
/// The support for image classification is under preview.
/// When the model file for <paramref name="arch"/> does not exist, it is downloaded (together with the
/// TensorFlow hub modules archive for InceptionV3) using file names relative to the current directory.
/// </remarks>
public static ImageClassificationEstimator ImageClassification(
    this ModelOperationsCatalog catalog,
    string featuresColumnName,
    string labelColumnName,
    string scoreColumnName = "Score",
    string predictedLabelColumnName = "PredictedLabel",
    Architecture arch = Architecture.InceptionV3,
    int epoch = 100,
    int batchSize = 10,
    float learningRate = 0.01f,
    bool disableEarlyStopping = false,
    EarlyStopping earlyStopping = null,
    ImageClassificationMetricsCallback metricsCallback = null,
    int statisticFrequency = 1,
    DnnFramework framework = DnnFramework.Tensorflow,
    string modelSavePath = null,
    string finalModelPrefix = "custom_retrained_model_based_on_",
    IDataView validationSet = null,
    bool testOnTrainSet = true,
    bool reuseTrainSetBottleneckCachedValues = false,
    bool reuseValidationSetBottleneckCachedValues = false,
    string trainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv",
    string validationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv")
{
    var options = new ImageClassificationEstimator.Options()
    {
        ModelLocation = ModelLocation[arch],
        InputColumns = new[] { featuresColumnName },
        OutputColumns = new[] { scoreColumnName, predictedLabelColumnName },
        LabelColumn = labelColumnName,
        TensorFlowLabel = labelColumnName,
        Epoch = epoch,
        LearningRate = learningRate,
        BatchSize = batchSize,
        // Early stopping is on unless explicitly disabled; default criteria are used when none are supplied.
        EarlyStoppingCriteria = disableEarlyStopping ? null : (earlyStopping ?? new EarlyStopping()),
        ScoreColumnName = scoreColumnName,
        PredictedLabelColumnName = predictedLabelColumnName,
        FinalModelPrefix = finalModelPrefix,
        Arch = arch,
        MetricsCallback = metricsCallback,
        StatisticsFrequency = statisticFrequency,
        Framework = framework,
        ModelSavePath = modelSavePath,
        ValidationSet = validationSet,
        TestOnTrainSet = testOnTrainSet,
        TrainSetBottleneckCachedValuesFilePath = trainSetBottleneckCachedValuesFilePath,
        ValidationSetBottleneckCachedValuesFilePath = validationSetBottleneckCachedValuesFilePath,
        ReuseTrainSetBottleneckCachedValues = reuseTrainSetBottleneckCachedValues,
        ReuseValidationSetBottleneckCachedValues = reuseValidationSetBottleneckCachedValues
    };

    if (!File.Exists(options.ModelLocation))
    {
        // NOTE(review): the download targets below are fixed relative file names, while the existence
        // check uses options.ModelLocation — presumably ModelLocation[arch] maps to the same names; confirm.
        using (var client = new WebClient())
        {
            switch (options.Arch)
            {
                case Architecture.InceptionV3:
                    client.DownloadFile(
                        new Uri(@"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"),
                        @"InceptionV3.meta");

                    // ZipFile.ExtractToDirectory throws an IOException when the extracted content is
                    // already present, which made this method fail whenever only the .meta file was
                    // missing. Skip the archive when a populated modules directory already exists.
                    var modulesDirectory = @"tfhub_modules";
                    var modulesAlreadyExtracted = Directory.Exists(modulesDirectory)
                        && Directory.GetFileSystemEntries(modulesDirectory).Length > 0;
                    if (!modulesAlreadyExtracted)
                    {
                        client.DownloadFile(
                            new Uri(@"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"),
                            @"tfhub_modules.zip");
                        ZipFile.ExtractToDirectory(
                            Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"),
                            modulesDirectory);
                    }

                    break;

                case Architecture.ResnetV2101:
                    client.DownloadFile(
                        new Uri(@"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta"),
                        @"resnet_v2_101_299.meta");
                    break;

                case Architecture.MobilenetV2:
                    client.DownloadFile(
                        new Uri(@"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta"),
                        @"mobilenet_v2.meta");
                    break;
            }
        }
    }

    var env = CatalogUtils.GetEnvironment(catalog);
    return new ImageClassificationEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true));
}
/// <summary>
/// Performs image classification using transfer learning.
/// </summary>
/// <param name="catalog">The catalog this extension method is invoked on; it is used to obtain the host environment.</param>
/// <param name="featuresColumnName">The name of the input features column.</param>
/// <param name="labelColumnName">The name of the labels column.</param>
/// <param name="outputGraphPath">Optional path where a copy of the new graph should be saved. The graph will be saved as part of the model.
/// NOTE(review): this method does not currently read the value — confirm whether it should be forwarded to the options.</param>
/// <param name="scoreColumnName">The name of the output score column.</param>
/// <param name="predictedLabelColumnName">The name of the output predicted label column.</param>
/// <param name="checkpointName">The name of the prefix for checkpoint files.</param>
/// <param name="arch">The architecture of the image recognition DNN model.</param>
/// <param name="dnnFramework">The backend DNN framework to use, currently only Tensorflow is supported.
/// This method does not currently read the value.</param>
/// <param name="epoch">Number of training epochs.</param>
/// <param name="batchSize">The batch size for training.</param>
/// <param name="learningRate">The learning rate for training.</param>
/// <returns>A <see cref="DnnEstimator"/> configured for transfer learning on the loaded model graph.</returns>
/// <remarks>
/// The support for image classification is under preview.
/// When the model file does not exist, it is downloaded (together with the TensorFlow hub modules
/// archive for InceptionV3) using file names relative to the current directory.
/// </remarks>
public static DnnEstimator ImageClassification(
    this ModelOperationsCatalog catalog,
    string featuresColumnName,
    string labelColumnName,
    string outputGraphPath = null,
    string scoreColumnName = "Score",
    string predictedLabelColumnName = "PredictedLabel",
    string checkpointName = "_retrain_checkpoint",
    Architecture arch = Architecture.ResnetV2101,
    DnnFramework dnnFramework = DnnFramework.Tensorflow,
    int epoch = 10,
    int batchSize = 20,
    float learningRate = 0.01f)
{
    var options = new Options()
    {
        // Any architecture other than ResnetV2101 falls back to the InceptionV3 graph file.
        ModelLocation = arch == Architecture.ResnetV2101 ? @"resnet_v2_101_299.meta" : @"InceptionV3.meta",
        InputColumns = new[] { featuresColumnName },
        OutputColumns = new[] { scoreColumnName, predictedLabelColumnName },
        LabelColumn = labelColumnName,
        TensorFlowLabel = labelColumnName,
        Epoch = epoch,
        LearningRate = learningRate,
        BatchSize = batchSize,
        // A batch dimension is added for every architecture except InceptionV3.
        AddBatchDimensionInputs = arch != Architecture.InceptionV3,
        TransferLearning = true,
        ScoreColumnName = scoreColumnName,
        PredictedLabelColumnName = predictedLabelColumnName,
        CheckpointName = checkpointName,
        Arch = arch,
        MeasureTrainAccuracy = false
    };

    if (!File.Exists(options.ModelLocation))
    {
        using (var client = new WebClient())
        {
            if (options.Arch == Architecture.InceptionV3)
            {
                client.DownloadFile(
                    new Uri(@"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"),
                    @"InceptionV3.meta");

                // ZipFile.ExtractToDirectory throws an IOException when the extracted content is
                // already present, which made this method fail whenever only the .meta file was
                // missing. Skip the archive when a populated modules directory already exists.
                var modulesDirectory = @"tfhub_modules";
                var modulesAlreadyExtracted = Directory.Exists(modulesDirectory)
                    && Directory.GetFileSystemEntries(modulesDirectory).Length > 0;
                if (!modulesAlreadyExtracted)
                {
                    client.DownloadFile(
                        new Uri(@"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"),
                        @"tfhub_modules.zip");
                    ZipFile.ExtractToDirectory(
                        Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"),
                        modulesDirectory);
                }
            }
            else if (options.Arch == Architecture.ResnetV2101)
            {
                client.DownloadFile(
                    new Uri(@"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta"),
                    @"resnet_v2_101_299.meta");
            }
        }
    }

    var env = CatalogUtils.GetEnvironment(catalog);
    return new DnnEstimator(env, options, DnnUtils.LoadDnnModel(env, options.ModelLocation, true));
}