示例#1
0
        /// <summary>
        /// Builds the image-classification pipeline, fits it on the images read from
        /// <c>imagesFolder</c>, saves the fitted model to <c>outputModelLocation</c>,
        /// and prints log-loss metrics computed on the same data used for training.
        /// </summary>
        public void Train()
        {
            ConsoleHelpers.ConsoleWriteHeader("Read model");
            Console.WriteLine($"Model location: {inputModelLocation}");
            Console.WriteLine($"Images folder: {imagesFolder}");
            Console.WriteLine($"Default parameters: image size=({ImageSettings.imageWidth},{ImageSettings.imageHeight}), image mean: {ImageSettings.mean}");

            // Get the training data sample images
            ConsoleHelpers.ConsoleWriteHeader("Collecting sample training data");

            var data = mlContext.Data.LoadFromEnumerable(DataHelper.ReadFromFolder(imagesFolder));

            Images.SummarizeTrainingData(imagesFolder);

            // Train the model
            ConsoleHelpers.ConsoleWriteHeader("Training classification model");

            // Stages: label -> key, load/resize/extract pixels, score with the TensorFlow
            // graph, then train a classifier on the "softmax2_pre_activation" output and
            // map the predicted key back to its original label value.
            var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: LabelTokey, inputColumnName: "Label")
                           .Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: "", inputColumnName: nameof(Images.ImagePath)))
                           .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageSettings.imageWidth, imageHeight: ImageSettings.imageHeight, inputColumnName: "input"))
                           .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageSettings.channelsLast, offsetImage: ImageSettings.mean))
                           .Append(mlContext.Model.LoadTensorFlowModel(inputModelLocation).
                                   ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
                           // A cache checkpoint only benefits estimators that come AFTER it, so it
                           // must precede the multi-pass trainer; placed at the very end of the
                           // chain (as before) it cached nothing that was ever re-read.
                           .AppendCacheCheckpoint(mlContext)
                           .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: LabelTokey, featureColumnName: "softmax2_pre_activation"))
                           .Append(mlContext.Transforms.Conversion.MapKeyToValue(PredictedLabelValue, "PredictedLabel"));

            ITransformer model = pipeline.Fit(data);

            // Save the model to assets/outputs
            ConsoleHelpers.ConsoleWriteHeader("Save model to local file");

            mlContext.Model.Save(model, data.Schema, outputModelLocation);

            Console.WriteLine($"Model saved: {outputModelLocation}");

            // Get some performance metrics on the model.
            // NOTE: the model is evaluated on the data it was fitted on, so these are
            // training-set metrics rather than a held-out estimate.
            var trainData             = model.Transform(data);
            var classificationContext = mlContext.MulticlassClassification;

            ConsoleHelpers.ConsoleWriteHeader("Evaluating classification metrics");
            var metrics = classificationContext.Evaluate(trainData, labelColumnName: LabelTokey, predictedLabelColumnName: "PredictedLabel");

            Console.WriteLine($"Total Log Loss is: {metrics.LogLoss}");
            Console.WriteLine($"Per Class LogLoss is: {String.Join(" , ", metrics.PerClassLogLoss.Select(c => c.ToString()))}");
        }
示例#2
0
        /// <summary>
        /// Entry point: trains the image classifier, then loads the saved model, scores
        /// every image found in the test folder and prints the results as indented JSON.
        /// Failures in either phase are reported through <c>ConsoleHelpers</c> and do not
        /// abort the process.
        /// </summary>
        static void Main()
        {
            // The featurizer model is resolved next to the executing assembly; the other
            // paths are resolved relative to the current working directory.
            string assetsPath = Path.Combine(new FileInfo(typeof(Program).Assembly.Location).Directory.FullName, "assets");

            var    trainingImages       = Path.GetFullPath("../../../assets/inputs/train");
            var    testingImages        = Path.GetFullPath("../../../../ImageClassification.Test/assets/test/thor");
            var    featurizerModel      = Path.Combine(assetsPath, "inputs", "inception", "tensorflow_inception_graph.pb");
            string trainedModelLocation = Path.GetFullPath("../../../imageClassifier.zip");

            //Train model
            try
            {
                var modelBuilder = new ModelBuilder(trainingImages, featurizerModel, trainedModelLocation);
                modelBuilder.Train();
            }
            catch (Exception ex)
            {
                ConsoleHelpers.ConsoleWriteException(ex.Message);
            }

            //Test model
            ConsoleHelpers.ConsoleWriteHeader("Test model with few sample images");
            try
            {
                ConsoleHelpers.ConsoleWriteHeader("Load saved model");
                TrainedModel model = new TrainedModel(trainedModelLocation);

                List <PredictionResult> results = new List <PredictionResult>();

                List <ImageData> imageList = DataHelper.ReadFromFolder(testingImages);
                foreach (ImageData image in imageList)
                {
                    // Predict returns the prediction for this image, so no placeholder
                    // instance needs to be allocated up front.
                    ImagePrediction prediction = model.predictor.Predict(image);

                    // Report the file name only, together with the highest class score.
                    results.Add(new PredictionResult(Path.GetFileName(image.ImagePath), prediction.PredictedLabelValue, prediction.Score.Max()));
                }

                var output = JsonConvert.SerializeObject(results, Formatting.Indented);
                Console.WriteLine(output);
            }
            catch (Exception ex)
            {
                ConsoleHelpers.ConsoleWriteException(ex.Message);
            }
        }
示例#3
0
        /// <summary>
        /// Entry point: scores the "zakladki" image set with a saved TensorFlow model and
        /// waits for a key press before exiting.
        /// </summary>
        /// <param name="args">
        /// Optional command-line arguments. When present, <c>args[0]</c> is used as the
        /// path to the model; otherwise the built-in default path is used.
        /// </param>
        static void Main(string[] args)
        {
            string assetsRelativePath = @"../../../assets";
            string assetsPath         = GetAbsolutePath(assetsRelativePath);

            var tagsTsv      = Path.Combine(assetsPath, "inputs", "zakladki", "image_list.tsv");
            var imagesFolder = Path.Combine(assetsPath, "inputs", "zakladki", "images");
            var labelsTxt    = Path.Combine(assetsPath, "inputs", "zakladki", "labels.txt");

            // The default is an absolute path that only exists on one machine; it is kept
            // so that running without arguments behaves as before, but the first
            // command-line argument can now override it.
            const string defaultPathToModel = @"D:\Files\GitHub\BinaryImageClassifier\BinaryImageClassifier\models\mobileNetV2\1564031768";
            var pathToModel = (args != null && args.Length > 0) ? args[0] : defaultPathToModel;

            try
            {
                var modelScorer = new TFModelScorer(tagsTsv, imagesFolder, pathToModel, labelsTxt);
                modelScorer.Score();
            }
            catch (Exception ex)
            {
                // ToString() keeps the stack trace and any inner exceptions in the output.
                ConsoleHelpers.ConsoleWriteException(ex.ToString());
            }

            ConsoleHelpers.ConsolePressAnyKey();
        }
        /// <summary>
        /// Entry point: scores the sample images with the Inception TensorFlow graph and
        /// waits for a key press before exiting. Any exception raised while scoring is
        /// printed in full instead of terminating the process.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            string assetsDir = GetAbsolutePath(@"../../../assets");

            // Both input sets live under <assets>/inputs.
            string imagesDir    = Path.Combine(assetsDir, "inputs", "images");
            string inceptionDir = Path.Combine(assetsDir, "inputs", "inception");

            string tagsFile   = Path.Combine(imagesDir, "tags.tsv");
            string graphFile  = Path.Combine(inceptionDir, "tensorflow_inception_graph.pb");
            string labelsFile = Path.Combine(inceptionDir, "imagenet_comp_graph_label_strings.txt");

            try
            {
                new TFModelScorer(tagsFile, imagesDir, graphFile, labelsFile).Score();
            }
            catch (Exception error)
            {
                ConsoleHelpers.ConsoleWriteException(error.ToString());
            }

            ConsoleHelpers.ConsolePressAnyKey();
        }
示例#5
0
        /// <summary>
        /// Entry point: scores the sample images with the Inception TensorFlow graph and
        /// waits for a key press before exiting. If scoring throws, the exception message
        /// is printed and execution continues to the key-press prompt.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            var assetsPath = ModelHelpers.GetAssetsPath(@"..\..\..\assets");

            var tagsTsv      = Path.Combine(assetsPath, "inputs", "images", "tags.tsv");
            var imagesFolder = Path.Combine(assetsPath, "inputs", "images");
            var inceptionPb  = Path.Combine(assetsPath, "inputs", "inception", "tensorflow_inception_graph.pb");
            var labelsTxt    = Path.Combine(assetsPath, "inputs", "inception", "imagenet_comp_graph_label_strings.txt");

            // Two locals pointing at inputs/inception_custom/model_tf.pb and
            // inputs/inception_custom/labels.txt were computed here but never used;
            // they have been removed. Only the stock Inception graph is scored.

            try
            {
                var modelScorer = new TFModelScorer(tagsTsv, imagesFolder, inceptionPb, labelsTxt);
                modelScorer.Score();
            }
            catch (Exception ex)
            {
                ConsoleHelpers.ConsoleWriteException(ex.Message);
            }

            ConsoleHelpers.ConsolePressAnyKey();
        }