Example #1
0
        public void ClassLabelGenerationBasicTest()
        {
            // Arrange: a single boolean "Label" column read from source index 0.
            var labelColumn = new TextLoader.Column()
            {
                Name     = "Label",
                Source   = new TextLoader.Range[] { new TextLoader.Range(0) },
                DataKind = DataKind.Boolean
            };

            var inferenceResults = new ColumnInferenceResults()
            {
                TextLoaderOptions = new TextLoader.Options()
                {
                    Columns        = new TextLoader.Column[] { labelColumn },
                    AllowQuoting   = false,
                    AllowSparse    = false,
                    Separators     = new[] { ',' },
                    HasHeader      = true,
                    TrimWhitespace = true
                },
                ColumnInformation = new ColumnInformation()
            };

            // Act: generate the class-label source lines for the inferred schema.
            var generator = new CodeGenerator(null, inferenceResults, null);
            var generated = generator.GenerateClassLabels();

            // Assert: attribute line first, then the property declaration.
            Assert.Equal("[ColumnName(\"Label\"), LoadColumn(0)]", generated[0]);
            Assert.Equal("public bool Label{get; set;}", generated[1]);
        }
        /// <summary>
        /// Converts a <see cref="TextLoader.Column"/> description into the corresponding
        /// <see cref="DataViewType"/> (vector, key, or primitive).
        /// </summary>
        /// <param name="col">Column description to convert.</param>
        /// <param name="ch">Optional channel used to raise exceptions; falls back to <see cref="Contracts"/> when null.</param>
        /// <returns>The equivalent <see cref="DataViewType"/>.</returns>
        public static DataViewType Convert(TextLoader.Column col, IChannel ch = null)
        {
            if (col.Source != null && col.Source.Length > 0)
            {
                if (col.Source.Length != 1)
                {
                    // Multiple source ranges would have to be merged; not supported here.
                    throw Contracts.ExceptNotImpl("Convert of TextLoader.Column is not implemented for more than one range.");
                }
                if (col.Source[0].ForceVector)
                {
                    if (!col.Source[0].Max.HasValue)
                    {
                        throw ch != null
                            ? ch.Except("A vector column needs a dimension")
                            : Contracts.Except("A vector column needs a dimension");
                    }
                    // Vector length is the inclusive span of the single source range.
                    int delta   = col.Source[0].Max.Value - col.Source[0].Min + 1;
                    var colType = DataKind2ColumnType(col.Type, ch);
                    return new VectorDataViewType(colType.AsPrimitive(), delta);
                }
            }
            if (col.KeyCount != null)
            {
                // Key columns carry a cardinality; 0 means "count unknown".
                var r = col.KeyCount;
                return new KeyDataViewType(DataKind2ColumnType(col.Type).RawType, r.Count ?? 0);
            }
            // Plain scalar column.
            return DataKind2ColumnType(col.Type, ch);
        }
        /// <summary>
        /// Builds one <see cref="TextLoader.Column"/> per suggested column, preserving order.
        /// </summary>
        public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
        {
            var result = new TextLoader.Column[columns.Length];

            for (int i = 0; i < columns.Length; i++)
            {
                var source = columns[i];
                result[i] = new TextLoader.Column(
                    source.SuggestedName,
                    source.ItemType.GetRawKind().ToDataKind(),
                    source.ColumnIndex);
            }
            return result;
        }
        /// <summary>
        /// Creates a mock connection whose text loader mirrors the output schema
        /// of the given <see cref="DatabaseLoader"/>.
        /// </summary>
        public MockConnection(MLContext context, DatabaseLoader databaseLoader)
        {
            var schema        = databaseLoader.GetOutputSchema();
            var loaderColumns = new TextLoader.Column[schema.Count];

            for (int index = 0; index < schema.Count; index++)
            {
                var schemaColumn = schema[index];
                var rawType      = schemaColumn.Type.RawType;

                // Every schema column is expected to map onto a known data kind.
                Assert.True(rawType.TryGetDataKind(out var dataKind));
                loaderColumns[index] = new TextLoader.Column(schemaColumn.Name, dataKind.ToDataKind(), index);
            }

            _textLoader = context.Data.CreateTextLoader(loaderColumns);
        }
        /// <summary>
        /// Creates a mock connection whose text loader mirrors the supplied
        /// <see cref="DatabaseLoader.Column"/> definitions.
        /// </summary>
        public MockConnection(MLContext context, DatabaseLoader.Column[] columns)
        {
            Columns = columns;

            var loaderColumns = new TextLoader.Column[columns.Length];

            for (int index = 0; index < columns.Length; index++)
            {
                var sourceColumn = columns[index];
                var rawType      = sourceColumn.Type.ToType();

                // Every declared column type must translate to a known data kind.
                Assert.True(rawType.TryGetDataKind(out var dataKind));
                loaderColumns[index] = new TextLoader.Column(sourceColumn.Name, dataKind.ToDataKind(), index);
            }

            _reader = context.Data.CreateTextLoader(loaderColumns);
        }
Example #6
0
        /// <summary>
        /// Reads every CSV record from the stream using ML.NET's text loader.
        /// The loader can only read from disk, so the payload is first persisted to a
        /// local file keyed on the stream length (written once, on the warmup run,
        /// and reused afterwards).
        /// </summary>
        public List <T> GetRecords <T>(MemoryStream stream) where T : ICsvReadable, new()
        {
            // Cache the payload on disk; a different stream length lands in a different file.
            var cachedFile = "data" + stream.Length + ".csv";

            if (!File.Exists(cachedFile))
            {
                using var target = File.Create(cachedFile);
                stream.CopyTo(target);
            }

            var activate = ActivatorFactory.Create <T>(_activationMethod);
            var records  = new List <T>();
            var context  = new MLContext();

            // NOTE(review): the reader wraps the input stream only to dispose it;
            // the data itself is loaded from the cached file.
            using (var reader = new StreamReader(stream))
            {
                // Declare 25 text columns named "0".."24".
                var columns = new TextLoader.Column[25];
                for (int c = 0; c < columns.Length; c++)
                {
                    columns[c] = new TextLoader.Column("" + c, DataKind.String, c);
                }

                var options = new TextLoader.Options()
                {
                    HasHeader  = false,
                    Separators = new[] { ',' },
                    Columns    = columns
                };

                var view    = context.Data.LoadFromTextFile(cachedFile, options);
                var cursor  = view.GetRowCursor(view.Schema);
                var getters = view.Schema
                              .ToArray()
                              .Select(col => cursor.GetGetter <ReadOnlyMemory <char> >(col))
                              .ToArray();

                // Materialize one record per row, fetching each cell through its getter.
                while (cursor.MoveNext())
                {
                    var record = activate();
                    record.Read(index =>
                    {
                        ReadOnlyMemory <char> cell = null;
                        getters[index](ref cell);
                        return cell.ToString();
                    });
                    records.Add(record);
                }
            }

            return records;
        }
        /// <summary>
        /// Reads text files as an IDataView, following the pandas API.
        /// Column types are guessed from the first file only, then a TextLoader
        /// is built over the whole file set.
        /// </summary>
        /// <param name="filenames">files to load; the first is used to guess column types</param>
        /// <param name="sep">column separator</param>
        /// <param name="header">has a header or not</param>
        /// <param name="names">column names (can be empty)</param>
        /// <param name="dtypes">column types (can be empty)</param>
        /// <param name="nrows">number of rows to read (&lt;= 0 means all)</param>
        /// <param name="guess_rows">number of rows used to guess types</param>
        /// <param name="encoding">text encoding</param>
        /// <param name="useThreads">specific to TextLoader</param>
        /// <param name="index">add a column to hold the index</param>
        /// <param name="host">host; a console environment is created when null</param>
        /// <returns>the loaded IDataView</returns>
        public static IDataView ReadCsvToTextLoader(string[] filenames,
                                                    char sep          = ',', bool header = true,
                                                    string[] names    = null, DataViewType[] dtypes = null,
                                                    int nrows         = -1, int guess_rows    = 10,
                                                    Encoding encoding = null, bool useThreads = true,
                                                    bool index        = false, IHost host     = null)
        {
            // Read only guess_rows rows of the first file to infer the schema.
            var sample = ReadCsv(filenames[0], sep: sep, header: header, names: names, dtypes: dtypes,
                                 nrows: guess_rows, guess_rows: guess_rows, encoding: encoding, index: index);
            var schema  = sample.Schema;
            var columns = new TextLoader.Column[schema.Count];

            for (int i = 0; i < columns.Length; ++i)
            {
                columns[i] = TextLoader.Column.Parse(sample.NameType(i));
                if (columns[i] == null)
                {
                    throw Contracts.Except("Unable to parse '{0}'.", sample.NameType(i));
                }
            }

            var options = new TextLoader.Options()
            {
                AllowQuoting   = false,
                Separators     = new[] { sep },
                Columns        = columns,
                TrimWhitespace = true,
                UseThreads     = useThreads,
                HasHeader      = header,
                MaxRows        = nrows > 0 ? (int?)nrows : null
            };

            if (host == null)
            {
                host = new ConsoleEnvironment().Register("TextLoader");
            }

            var source = new MultiFileSource(filenames);
            return new TextLoader(host, options, source).Load(source);
        }
Example #8
0
        /// <summary>
        /// Converts a <see cref="TextLoader.Column"/> description into the corresponding
        /// <see cref="ColumnType"/> (vector, key, or primitive).
        /// </summary>
        /// <param name="col">Column description to convert; its <c>Type</c> must be set.</param>
        /// <param name="ch">Optional channel used to raise exceptions; falls back to <see cref="Contracts"/> when null.</param>
        /// <returns>The equivalent <see cref="ColumnType"/>.</returns>
        public static ColumnType Convert(TextLoader.Column col, IChannel ch = null)
        {
            if (!col.Type.HasValue)
            {
                throw ch != null ? ch.Except("Kind is null") : Contracts.Except("Kind is null");
            }
            if (col.Source != null && col.Source.Length > 0)
            {
                if (col.Source.Length != 1)
                {
                    // Multiple source ranges would have to be merged; not supported here.
                    throw Contracts.ExceptNotImpl("Convert of TextLoader.Column is not implemented for more than one range.");
                }
                if (col.Source[0].ForceVector)
                {
                    if (!col.Source[0].Max.HasValue)
                    {
                        throw ch != null
                            ? ch.Except("A vector column needs a dimension")
                            : Contracts.Except("A vector column needs a dimension");
                    }
                    // Vector length is the inclusive span of the single source range.
                    int delta   = col.Source[0].Max.Value - col.Source[0].Min + 1;
                    var colType = DataKind2ColumnType(col.Type.Value, ch);
                    return new VectorType(colType.AsPrimitive(), delta);
                }
            }
            if (col.KeyRange != null)
            {
                // Key columns: count is the inclusive range size, or 0 when the max is open.
                // col.Type.HasValue is guaranteed by the guard at the top of the method.
                var r = col.KeyRange;
                return new KeyType(col.Type.Value.ToType(), r.Min,
                                   r.Max.HasValue ? (int)(r.Max.Value - r.Min + 1) : 0,
                                   r.Contiguous);
            }
            // Plain scalar column.
            return DataKind2ColumnType(col.Type.Value, ch);
        }
Example #9
0
        /// <summary>
        /// Runs a stored ML.NET model against the CSV payload in <paramref name="input"/>
        /// and returns the first predicted value. The CSV is written to a temp file for
        /// loading and deleted (best effort) afterwards.
        /// </summary>
        /// <param name="db">Database used to look up the model, its fields, and its serialized bytes.</param>
        /// <param name="input">Carries the model id and the raw CSV data to score.</param>
        /// <returns>A result whose Item is the first prediction; Success is false on any error.</returns>
        public static ReturnResult <dynamic> PredictModel(IDatabase db, PredictionInput input)
        {
            var results  = new ReturnResult <dynamic>();
            // Random name for the temp CSV file holding the incoming data.
            var fileName = Guid.NewGuid().ToString();

            try
            {
                // Load model metadata, its field definitions, and the serialized model bytes.
                var model     = GetModelById(db, input.ModelId);
                var modelInfo = ModelField.GetFieldsByModelId(db, model.Item.ModelId);
                var modelFile = FileStore.GetById(db, model.Item.FileStoreId.Value);

                MLContext      mlContext = new MLContext();
                DataViewSchema predictionPipelineSchema;

                IDataView LoadedData = null;

                // Build one TextLoader column per model field, in declaration order,
                // so CSV column i feeds model field i.
                var columnData = new List <TextLoader.Column>();
                var curItem    = 0;
                foreach (var c in modelInfo.Item)
                {
                    var newColData = new TextLoader.Column()
                    {
                        DataKind = (DataKind)c.DataTypeId,
                        Name     = c.Name,
                        Source   = new TextLoader.Range[] { new TextLoader.Range(curItem) }
                    };

                    columnData.Add(newColData);
                    curItem++;
                }

                // Persist the CSV payload so the text loader can read it from disk.
                File.WriteAllText(Path.GetTempPath() + fileName, input.CsvData);

                LoadedData = mlContext.Data.LoadFromTextFile(
                    Path.GetTempPath() + fileName,
                    columnData.ToArray(),
                    separatorChar: ',',
                    hasHeader: false,
                    allowQuoting: true
                    );

                // Exactly one field must be flagged as the output column (Single throws otherwise).
                var outputColumn = modelInfo.Item.Single(x => x.IsOutput == true);

                using (MemoryStream ms = new MemoryStream(modelFile.Item.Data))
                {
                    var       predictionPipeline = mlContext.Model.Load(ms, out predictionPipelineSchema);
                    IDataView predictions        = predictionPipeline.Transform(LoadedData);

                    if ((DataKind)outputColumn.DataTypeId == DataKind.Single)
                    {
                        // No "PredictedLabel" column means a regression-style model: read "Score".
                        var      hasLabel      = predictions.Schema.GetColumnOrNull("PredictedLabel");
                        Single[] predictionOut = null;
                        if (hasLabel == null)
                        {
                            predictionOut = predictions.GetColumn <Single>("Score").ToArray();
                        }
                        else
                        {
                            predictionOut = predictions.GetColumn <Single>("PredictedLabel").ToArray();
                        }
                        results.Item = predictionOut[0];
                    }
                    else if ((DataKind)outputColumn.DataTypeId == DataKind.String)
                    {
                        var predictionOut = predictions.GetColumn <String>("PredictedLabel").ToArray();
                        results.Item = predictionOut[0];
                    }
                    else if ((DataKind)outputColumn.DataTypeId == DataKind.Boolean)
                    {
                        var predictionOut = predictions.GetColumn <Boolean>("PredictedLabel").ToArray();
                        results.Item = predictionOut[0];
                    }
                    // NOTE(review): any other output data kind falls through and leaves results.Item unset.
                };

                results.Success = true;
            }
            catch (Exception e)
            {
                // NOTE(review): the exception is swallowed; callers only see Success=false.
                results.Success = false;
            }


            // Delete the temp prediction file (best effort; failures deliberately ignored).
            try
            {
                if (File.Exists(Path.GetTempPath() + fileName))
                {
                    File.Delete(Path.GetTempPath() + fileName);
                }
            }
            catch (Exception e) {}

            return(results);
        }
Example #10
0
        /// <summary>
        /// The main program entry point: trains an SDCA multiclass model on the MNIST
        /// digit data, reports its metrics, and prints score distributions for a few
        /// sampled test digits.
        /// </summary>
        /// <param name="args">The command line arguments.</param>
        static void Main(string[] args)
        {
            // set up the machine learning context
            var mlContext = new MLContext();

            // input schema: 784 pixel values (columns 1..784) plus the digit in column 0
            Console.WriteLine("Loading data....");
            var schema = new TextLoader.Column[]
            {
                new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
                new TextLoader.Column("Number", DataKind.Single, 0)
            };

            // load the training and test sets
            var trainingData = mlContext.Data.LoadFromTextFile(
                path: trainDataPath,
                columns: schema,
                hasHeader: true,
                separatorChar: ',');
            var evaluationData = mlContext.Data.LoadFromTextFile(
                path: testDataPath,
                columns: schema,
                hasHeader: true,
                separatorChar: ',');

            // training pipeline:
            //   1. map the number column to a key value stored in the label column
            //   2. concatenate all feature columns into "Features"
            //   3. cache the data to speed up training
            //   4. train with SDCA maximum entropy
            //   5. map the label key value back to a number
            var trainingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
                outputColumnName: "Label",
                inputColumnName: "Number",
                keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)
                .Append(mlContext.Transforms.Concatenate(
                            "Features",
                            nameof(Digit.PixelValues)))
                .AppendCacheCheckpoint(mlContext)
                .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(
                            labelColumnName: "Label",
                            featureColumnName: "Features"))
                .Append(mlContext.Transforms.Conversion.MapKeyToValue(
                            outputColumnName: "Number",
                            inputColumnName: "Label"));

            // fit the model on the training set
            Console.WriteLine("Training model....");
            var trainedModel = trainingPipeline.Fit(trainingData);

            // score the test set
            Console.WriteLine("Evaluating model....");
            var scoredData = trainedModel.Transform(evaluationData);

            // compute multiclass metrics on the predictions
            var metrics = mlContext.MulticlassClassification.Evaluate(
                data: scoredData,
                labelColumnName: "Number",
                scoreColumnName: "Score");

            // report the evaluation metrics
            Console.WriteLine($"Evaluation metrics");
            Console.WriteLine($"    MicroAccuracy:    {metrics.MicroAccuracy:0.###}");
            Console.WriteLine($"    MacroAccuracy:    {metrics.MacroAccuracy:0.###}");
            Console.WriteLine($"    LogLoss:          {metrics.LogLoss:#.###}");
            Console.WriteLine($"    LogLossReduction: {metrics.LogLossReduction:#.###}");
            Console.WriteLine();

            // sample five digits from the test data to spot-check
            var allDigits = mlContext.Data.CreateEnumerable <Digit>(evaluationData, reuseRowObject: false).ToArray();
            var sampleSet = new Digit[] { allDigits[5], allDigits[16], allDigits[28], allDigits[63], allDigits[129] };

            // build a prediction engine for single-instance scoring
            var predictionEngine = mlContext.Model.CreatePredictionEngine <Digit, DigitPrediction>(trainedModel);

            // one table column per possible digit probability
            var resultTable = new Table(TableConfiguration.Unicode());

            resultTable.AddColumn("Digit");
            for (var digit = 0; digit < 10; digit++)
            {
                resultTable.AddColumn($"P{digit}");
            }

            // predict each sampled digit and record its score distribution
            for (var s = 0; s < sampleSet.Length; s++)
            {
                var prediction = predictionEngine.Predict(sampleSet[s]);
                resultTable.AddRow(
                    sampleSet[s].Number,
                    prediction.Score[0].ToString("P2"),
                    prediction.Score[1].ToString("P2"),
                    prediction.Score[2].ToString("P2"),
                    prediction.Score[3].ToString("P2"),
                    prediction.Score[4].ToString("P2"),
                    prediction.Score[5].ToString("P2"),
                    prediction.Score[6].ToString("P2"),
                    prediction.Score[7].ToString("P2"),
                    prediction.Score[8].ToString("P2"),
                    prediction.Score[9].ToString("P2"));
            }

            // show the results and wait for a key press
            Console.WriteLine(resultTable.ToString());
            Console.ReadKey();
        }
        // This method is called if only a datafile is specified, without a loader/term and value columns.
        // It determines the type of the Value column and returns the appropriate TextLoader component factory.
        // When keyValues is set, the file is scanned once to find the [min, max] range of the key values,
        // and the returned factory declares the Value column as a key type covering that range.
        private static IComponentFactory <IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
        {
            Contracts.AssertValue(host);

            // If the user specified non-key values, we define the value column to be numeric.
            if (!keyValues)
            {
                return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                           (env, files) => TextLoader.Create(
                               env,
                               new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                               files)));
            }

            // If the user specified key values, we scan the values to determine the range of the key type.
            ulong min = ulong.MaxValue;
            ulong max = ulong.MinValue;

            try
            {
                // Read both columns as text so each value can be probed for key-ness below.
                var  txtArgs = new TextLoader.Arguments();
                bool parsed  = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
                host.Assert(parsed);
                var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
                using (var cursor = data.GetRowCursor(c => true))
                {
                    var    getTerm = cursor.GetGetter <DvText>(0);
                    var    getVal  = cursor.GetGetter <DvText>(1);
                    DvText txt     = default(DvText);

                    using (var ch = host.Start("Creating Text Lookup Loader"))
                    {
                        long countNonKeys = 0;
                        while (cursor.MoveNext())
                        {
                            getVal(ref txt);
                            ulong res;
                            // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0,
                            // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for
                            // computing max and min.
                            if (Conversions.Instance.TryParseKey(ref txt, 1, ulong.MaxValue, out res))
                            {
                                if (res < min && res != 0)
                                {
                                    min = res;
                                }
                                if (res > max)
                                {
                                    max = res;
                                }
                            }
                            // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds,
                            // then the value is 0, and we update min accordingly.
                            else if (Conversions.Instance.TryParse(ref txt, out res))
                            {
                                ch.Assert(res == 0);
                                min = 0;
                            }
                            // If parsing as a ulong fails, we increment the counter for the non-key values.
                            // Only the first five offenders are reported individually.
                            else
                            {
                                var term = default(DvText);
                                getTerm(ref term);
                                if (countNonKeys < 5)
                                {
                                    ch.Warning("Term '{0}' in mapping file is mapped to non key value '{1}'", term, txt);
                                }
                                countNonKeys++;
                            }
                        }
                        if (countNonKeys > 0)
                        {
                            ch.Warning("Found {0} non key values in the file '{1}'", countNonKeys, filename);
                        }
                        // min > max means no valid key value was seen; fall back to a full uint range.
                        if (min > max)
                        {
                            min = 0;
                            max = uint.MaxValue - 1;
                            ch.Warning("did not find any valid key values in the file '{0}'", filename);
                        }
                        else
                        {
                            ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename);
                        }
                        ch.Done();
                    }
                }
            }
            catch (Exception e)
            {
                throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename);
            }

            // Build the Value column as a key type sized to the observed range:
            // U4 with a bounded range when it fits in int/uint, otherwise U8 with an open range.
            TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1);
            if (max - min < (ulong)int.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min, max);
            }
            else if (max - min < (ulong)uint.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min);
            }
            else
            {
                valueColumn.Type     = DataKind.U8;
                valueColumn.KeyRange = new KeyRange(min);
            }

            return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                       (env, files) => TextLoader.Create(
                           env,
                           new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("Term", DataKind.TX, 0),
                    valueColumn
                }
            },
                           files)));
        }
        /// <summary>
        /// Entry point: trains a one-versus-all fast-forest multiclass model on the
        /// MNIST digit data, reports its metrics, and prints score distributions for
        /// one hand-picked example of each digit.
        /// </summary>
        static void Main(string[] args)
        {
            var mlContext = new MLContext();

            Console.WriteLine("Loading Data...");

            // input schema: 784 pixel values (columns 1..784) plus the digit in column 0
            var schema = new TextLoader.Column[] {
                new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
                new TextLoader.Column("Number", DataKind.Single, 0)
            };

            var trainingData   = mlContext.Data.LoadFromTextFile(trainDataPath, schema, hasHeader: true, separatorChar: ',');
            var evaluationData = mlContext.Data.LoadFromTextFile(testDataPath, schema, hasHeader: true, separatorChar: ',');

            // pipeline: key-encode the label, concatenate features, cache,
            // train one-versus-all fast forest, then decode the label key
            var trainingPipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "Label", inputColumnName: "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)
                .Append(mlContext.Transforms.Concatenate("Features", nameof(Digit.PixelValues)))
                .AppendCacheCheckpoint(mlContext)
                .Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(), "Label"))
                .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "Number", inputColumnName: "Label"));

            Console.WriteLine("Training the model...");
            var trainedModel = trainingPipeline.Fit(trainingData);

            Console.WriteLine("Evaluating model...");
            var scoredData = trainedModel.Transform(evaluationData);

            var metrics = mlContext.MulticlassClassification.Evaluate(scoredData, labelColumnName: "Number", scoreColumnName: "Score");

            // report the evaluation metrics
            Console.WriteLine($"Evaluation metrics");
            Console.WriteLine($"    MicroAccuracy:    {metrics.MicroAccuracy:0.###}");
            Console.WriteLine($"    MacroAccuracy:    {metrics.MacroAccuracy:0.###}");
            Console.WriteLine($"    LogLoss:          {metrics.LogLoss:#.###}");
            Console.WriteLine($"    LogLossReduction: {metrics.LogLossReduction:#.###}");
            Console.WriteLine();

            var allDigits = mlContext.Data.CreateEnumerable <Digit>(evaluationData, false).ToArray();

            // one hand-picked test example of each digit 0-9
            var sampleSet = new Digit[] {
                allDigits[215], // 0
                allDigits[202], // 1
                allDigits[199], // 2
                allDigits[200], // 3
                allDigits[198], // 4
                allDigits[207], // 5
                allDigits[201], // 6
                allDigits[220], // 7
                allDigits[226], // 8
                allDigits[235]  // 9
            };

            var predictionEngine = mlContext.Model.CreatePredictionEngine <Digit, DigitPrediction>(trainedModel);

            // one table column per possible digit probability
            var resultTable = new BetterConsoleTables.Table(TableConfiguration.Unicode());

            resultTable.AddColumn("Digits");
            for (var digit = 0; digit < 10; digit++)
            {
                resultTable.AddColumn($"P{digit}");
            }

            // predict each sampled digit and record its score distribution
            for (var s = 0; s < sampleSet.Length; s++)
            {
                var prediction = predictionEngine.Predict(sampleSet[s]);
                resultTable.AddRow(
                    sampleSet[s].Number,
                    prediction.Score[0].ToString("P2"),
                    prediction.Score[1].ToString("P2"),
                    prediction.Score[2].ToString("P2"),
                    prediction.Score[3].ToString("P2"),
                    prediction.Score[4].ToString("P2"),
                    prediction.Score[5].ToString("P2"),
                    prediction.Score[6].ToString("P2"),
                    prediction.Score[7].ToString("P2"),
                    prediction.Score[8].ToString("P2"),
                    prediction.Score[9].ToString("P2"));
            }

            // show results
            Console.WriteLine(resultTable.ToString());
        }
Example #13
0
        /// <summary>
        /// The main program entry point.
        /// </summary>
        /// <param name="args">The command line arguments.</param>
        static void Main(string[] args)
        {
            // create a machine learning context
            var context = new MLContext();

            // load data
            Console.WriteLine("Loading data....");
            var columnDef = new TextLoader.Column[]
            {
                new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
                new TextLoader.Column(nameof(Digit.Number), DataKind.Single, 0)
            };
            var trainDataView = context.Data.LoadFromTextFile(
                path: trainDataPath,
                columns: columnDef,
                hasHeader: true,
                separatorChar: ',');
            var testDataView = context.Data.LoadFromTextFile(
                path: testDataPath,
                columns: columnDef,
                hasHeader: true,
                separatorChar: ',');

            // load training and testing data
            var training = context.Data.CreateEnumerable <Digit>(trainDataView, reuseRowObject: false);
            var testing  = context.Data.CreateEnumerable <Digit>(testDataView, reuseRowObject: false);

            // set up data arrays
            var training_data   = training.Select(v => v.GetFeatures()).ToArray();
            var training_labels = training.Select(v => v.GetLabel()).ToArray();
            var testing_data    = testing.Select(v => v.GetFeatures()).ToArray();
            var testing_labels  = testing.Select(v => v.GetLabel()).ToArray();

            // build features and labels
            var features = NetUtil.Var(new int[] { 28, 28 }, DataType.Float);
            var labels   = NetUtil.Var(new int[] { 10 }, DataType.Float);

            // build the network
            var network = features
                          .Dense(512, CNTKLib.ReLU)
                          .Dense(10, CNTKLib.Softmax)
                          .ToNetwork();

            Console.WriteLine("Model architecture:");
            Console.WriteLine(network.ToSummary());

            // set up the loss function and the classification error function
            var lossFunc  = CNTKLib.CrossEntropyWithSoftmax(network.Output, labels);
            var errorFunc = CNTKLib.ClassificationError(network.Output, labels);

            // set up a trainer that uses the RMSProp algorithm
            var learner = network.GetRMSPropLearner(
                learningRateSchedule: 0.99,
                gamma: 0.95,
                inc: 2.0,
                dec: 0.5,
                max: 2.0,
                min: 0.5
                );

            // set up a trainer and an evaluator
            var trainer   = network.GetTrainer(learner, lossFunc, errorFunc);
            var evaluator = network.GetEvaluator(errorFunc);

            // train the model
            Console.WriteLine("Epoch\tTrain\tTrain\tTest");
            Console.WriteLine("\tLoss\tError\tError");
            Console.WriteLine("-----------------------------");

            var maxEpochs     = 50;
            var batchSize     = 128;
            var loss          = new double[maxEpochs];
            var trainingError = new double[maxEpochs];
            var testingError  = new double[maxEpochs];
            var batchCount    = 0;

            for (int epoch = 0; epoch < maxEpochs; epoch++)
            {
                // train one epoch on batches
                loss[epoch]          = 0.0;
                trainingError[epoch] = 0.0;
                batchCount           = 0;
                training_data.Index().Shuffle().Batch(batchSize, (indices, begin, end) =>
                {
                    // get the current batch
                    var featureBatch = features.GetBatch(training_data, indices, begin, end);
                    var labelBatch   = labels.GetBatch(training_labels, indices, begin, end);

                    // train the network on the batch
                    var result = trainer.TrainBatch(
                        new[] {
                        (features, featureBatch),
                        (labels, labelBatch)
                    },
Exemple #14
0
        /// <summary>
        /// Azure Function entry point that trains an AutoML model from CSV data posted in the
        /// request body and persists the best model to the database, all inside one transaction.
        /// </summary>
        /// <param name="req">HTTP POST request whose body is a JSON-serialized <c>TrainInput</c>
        /// (raw CSV data, column schema, separator/header options, label column, model type).</param>
        /// <param name="log">Function logger.</param>
        /// <returns>A <c>ReturnResult&lt;Model&gt;</c> with the saved model metadata on success,
        /// or <c>Success = false</c> and the caught exception on failure.</returns>
        public static ReturnResult <Model> Run([HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = null)] HttpRequest req, ILogger log)
        {
            // AutoML loads training data from disk, so the posted CSV is staged in a unique temp file.
            var dataFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());

            try
            {
                db.BeginTransaction();

                MLContext context = new MLContext();

                TrainInput input = null;

                using (StreamReader reader = new StreamReader(req.Body))
                {
                    input = JsonConvert.DeserializeObject <TrainInput>(reader.ReadToEnd());
                }

                File.WriteAllText(dataFilePath, input.Data);

                // Build the text-loader schema from the user-supplied column list,
                // skipping columns marked with data type 1 ("ignore").
                var columnData = new List <TextLoader.Column>();
                foreach (var c in input.Columns)
                {
                    if (c.Type != 1)
                    {
                        columnData.Add(new TextLoader.Column()
                        {
                            DataKind = (DataKind)c.Type,
                            Name     = c.ColumnName,
                            Source   = new TextLoader.Range[] { new TextLoader.Range(c.ColumnIndex) }
                        });
                    }
                }

                IDataView loadedData = context.Data.LoadFromTextFile(
                    dataFilePath,
                    columnData.ToArray(),
                    separatorChar: input.Separator,
                    hasHeader: input.HasHeaders,
                    allowQuoting: true
                    );

                // Shuffle so that AutoML's train/validation split is not biased by row order.
                loadedData = context.Data.ShuffleRows(loadedData);

                /*
                 * Multiclass will be used in the case of binary experiments and multiclass experiments.
                 * This is because multiclass can accept all types as an output column. This will
                 * allow less interaction with the user and a better user experience.
                 */

                double       bestRunMetric = 0;
                ITransformer bestModel     = null;

                if (input.ModelType == TrainInput.ModelTypes.Multiclass)
                {
                    var settings = new MulticlassExperimentSettings()
                    {
                        MaxExperimentTimeInSeconds = 20
                    };
                    var experiment = context.Auto().CreateMulticlassClassificationExperiment(settings);
                    ExperimentResult <MulticlassClassificationMetrics> results =
                        experiment.Execute(loadedData, labelColumnName: input.LabelColumn);
                    bestRunMetric = results.BestRun.ValidationMetrics.MacroAccuracy;
                    bestModel     = results.BestRun.Model;
                }
                else if (input.ModelType == TrainInput.ModelTypes.Binary)
                {
                    var settings = new BinaryExperimentSettings()
                    {
                        MaxExperimentTimeInSeconds = 20
                    };
                    var experiment = context.Auto().CreateBinaryClassificationExperiment(settings);
                    ExperimentResult <BinaryClassificationMetrics> results =
                        experiment.Execute(loadedData, labelColumnName: input.LabelColumn);
                    bestRunMetric = results.BestRun.ValidationMetrics.Accuracy;
                    bestModel     = results.BestRun.Model;
                }
                else if (input.ModelType == TrainInput.ModelTypes.Regression)
                {
                    var settings = new RegressionExperimentSettings()
                    {
                        MaxExperimentTimeInSeconds = 20
                    };
                    var experiment = context.Auto().CreateRegressionExperiment(settings);
                    ExperimentResult <RegressionMetrics> results =
                        experiment.Execute(loadedData, labelColumnName: input.LabelColumn);
                    bestRunMetric = results.BestRun.ValidationMetrics.RSquared;
                    bestModel     = results.BestRun.Model;
                    // R-squared can be negative for models worse than the mean predictor;
                    // clamp to zero so the stored accuracy stays in a sensible range.
                    if (bestRunMetric < 0)
                    {
                        bestRunMetric = 0;
                    }
                }
                else
                {
                    throw new Exception("Invalid model type");
                }

                var modelFileId = 0;

                using (MemoryStream ms = new MemoryStream())
                {
                    context.Model.Save(bestModel, loadedData.Schema, ms);
                    //Save model to the database
                    FileStore modelSave = new FileStore()
                    {
                        Data = ms.ToArray()
                    };

                    modelFileId = FileStore.InsertUpdate(db, modelSave).Item.FileStoreId;
                }

                var resultModel = new Model()
                {
                    FileStoreId = modelFileId,
                    Accuracy    = bestRunMetric,
                    Rows        = input.Data.Trim().Split('\n').Length
                };

                db.CompleteTransaction();

                return(new ReturnResult <Model>()
                {
                    Success = true,
                    Item = resultModel
                });
            }
            catch (Exception e)
            {
                db.AbortTransaction();
                // Log the full exception (not just the message) so the stack trace is preserved.
                log.LogError(e, e.Message);
                return(new ReturnResult <Model>()
                {
                    Success = false,
                    Exception = e
                });
            }
            finally
            {
                // The staged temp file was previously leaked on every request; always clean it up.
                try
                {
                    if (File.Exists(dataFilePath))
                    {
                        File.Delete(dataFilePath);
                    }
                }
                catch (IOException cleanupEx)
                {
                    // Best-effort cleanup: a locked/undeletable temp file must not mask the real result.
                    log.LogWarning(cleanupEx, "Failed to delete temporary training file.");
                }
            }
        }