public void ClassLabelGenerationBasicTest()
{
    var columns = new TextLoader.Column[]
    {
        new TextLoader.Column()
        {
            Name = "Label",
            Source = new TextLoader.Range[] { new TextLoader.Range(0) },
            DataKind = DataKind.Boolean
        },
    };

    var result = new ColumnInferenceResults()
    {
        TextLoaderOptions = new TextLoader.Options()
        {
            Columns = columns,
            AllowQuoting = false,
            AllowSparse = false,
            Separators = new[] { ',' },
            HasHeader = true,
            TrimWhitespace = true
        },
        ColumnInformation = new ColumnInformation()
    };

    CodeGenerator codeGenerator = new CodeGenerator(null, result, null);
    var actual = codeGenerator.GenerateClassLabels();

    var expected1 = "[ColumnName(\"Label\"), LoadColumn(0)]";
    var expected2 = "public bool Label{get; set;}";

    Assert.Equal(expected1, actual[0]);
    Assert.Equal(expected2, actual[1]);
}
/// <summary>
/// Converts a <see cref="TextLoader.Column"/> description into the corresponding <see cref="DataViewType"/>.
/// </summary>
public static DataViewType Convert(TextLoader.Column col, IChannel ch = null)
{
    if (col.Source != null && col.Source.Length > 0)
    {
        if (col.Source.Length != 1)
        {
            throw Contracts.ExceptNotImpl("Convert of TextLoader.Column is not implemented for more than one range.");
        }
        if (col.Source[0].ForceVector)
        {
            if (!col.Source[0].Max.HasValue)
            {
                throw ch != null ? ch.Except("A vector column needs a dimension")
                                 : Contracts.Except("A vector column needs a dimension");
            }
            int delta = col.Source[0].Max.Value - col.Source[0].Min + 1;
            var colType = DataKind2ColumnType(col.Type, ch);
            return new VectorDataViewType(colType.AsPrimitive(), delta);
        }
    }
    if (col.KeyCount != null)
    {
        var r = col.KeyCount;
        return new KeyDataViewType(DataKind2ColumnType(col.Type).RawType, r.Count.HasValue ? r.Count.Value : 0);
    }
    return DataKind2ColumnType(col.Type, ch);
}
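A minimal usage sketch, assuming the newer Microsoft.ML API and the DataKind2ColumnType helper from the same codebase in scope: a single range covering indices 0 through 9 with ForceVector set converts to a ten-slot vector type.

// Hypothetical call; column name and indices are illustrative.
var col = new TextLoader.Column("Features", DataKind.Single,
    new[] { new TextLoader.Range(0, 9) { ForceVector = true } });
var type = Convert(col);   // a VectorDataViewType of Single with 10 slots
Console.WriteLine(type);   // prints something like "Vector<Single, 10>"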
public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
{
    var loaderColumns = new List<TextLoader.Column>();
    foreach (var col in columns)
    {
        var loaderColumn = new TextLoader.Column(col.SuggestedName, col.ItemType.GetRawKind().ToDataKind(), col.ColumnIndex);
        loaderColumns.Add(loaderColumn);
    }
    return loaderColumns.ToArray();
}
public MockConnection(MLContext context, DatabaseLoader databaseLoader)
{
    var outputSchema = databaseLoader.GetOutputSchema();

    var readerColumns = new TextLoader.Column[outputSchema.Count];
    for (int i = 0; i < outputSchema.Count; i++)
    {
        var column = outputSchema[i];
        var columnType = column.Type.RawType;
        Assert.True(columnType.TryGetDataKind(out var internalDataKind));
        readerColumns[i] = new TextLoader.Column(column.Name, internalDataKind.ToDataKind(), i);
    }

    _textLoader = context.Data.CreateTextLoader(readerColumns);
}
public MockConnection(MLContext context, DatabaseLoader.Column[] columns)
{
    Columns = columns;

    var readerColumns = new TextLoader.Column[columns.Length];
    for (int i = 0; i < columns.Length; i++)
    {
        var column = columns[i];
        var columnType = column.Type.ToType();
        Assert.True(columnType.TryGetDataKind(out var internalDataKind));
        readerColumns[i] = new TextLoader.Column(column.Name, internalDataKind.ToDataKind(), i);
    }

    _reader = context.Data.CreateTextLoader(readerColumns);
}
public List<T> GetRecords<T>(MemoryStream stream) where T : ICsvReadable, new()
{
    // ML.NET's TextLoader only loads from files, so spill the memory stream to a
    // local file first. The stream length is part of the file name so that different
    // input data sets get different cache files; the file is created during the
    // first "warmup" run and reused afterwards.
    var file = "data" + stream.Length + ".csv";
    if (!File.Exists(file))
    {
        using var data = File.Create(file);
        stream.CopyTo(data);
    }

    var activate = ActivatorFactory.Create<T>(_activationMethod);
    var allRecords = new List<T>();
    var mlc = new MLContext();

    using (var reader = new StreamReader(stream))
    {
        // declare all 25 columns as strings and read them back as text
        var schema = new TextLoader.Column[25];
        for (int i = 0; i < schema.Length; i++)
        {
            schema[i] = new TextLoader.Column("" + i, DataKind.String, i);
        }
        var opts = new TextLoader.Options()
        {
            HasHeader = false,
            Separators = new[] { ',' },
            Columns = schema
        };
        var l = mlc.Data.LoadFromTextFile(file, opts);
        var rc = l.GetRowCursor(l.Schema);
        var cols = l.Schema.ToArray();
        var getters = cols.Select(c => rc.GetGetter<ReadOnlyMemory<char>>(c)).ToArray();
        while (rc.MoveNext())
        {
            var record = activate();
            record.Read(i =>
            {
                ReadOnlyMemory<char> s = default;
                getters[i](ref s);
                return s.ToString();
            });
            allRecords.Add(record);
        }
    }
    return allRecords;
}
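The benchmark leans on an ICsvReadable contract that is not shown here; a plausible minimal shape, inferred from the record.Read(i => ...) call above, would be:

// Hypothetical contract, inferred from usage: a record populates itself by
// pulling each field as a string, addressed by zero-based column index.
public interface ICsvReadable
{
    void Read(Func<int, string> getField);
}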
/// <summary>
/// Reads text files as an IDataView.
/// Follows the pandas API.
/// </summary>
/// <param name="filenames">filenames</param>
/// <param name="sep">column separator</param>
/// <param name="header">has a header or not</param>
/// <param name="names">column names (can be empty)</param>
/// <param name="dtypes">column types (can be empty)</param>
/// <param name="nrows">number of rows to read</param>
/// <param name="guess_rows">number of rows used to guess types</param>
/// <param name="encoding">text encoding</param>
/// <param name="useThreads">specific to TextLoader</param>
/// <param name="host">host</param>
/// <param name="index">add a column to hold the index</param>
/// <returns>the loaded IDataView</returns>
public static IDataView ReadCsvToTextLoader(string[] filenames,
                                            char sep = ',',
                                            bool header = true,
                                            string[] names = null,
                                            DataViewType[] dtypes = null,
                                            int nrows = -1,
                                            int guess_rows = 10,
                                            Encoding encoding = null,
                                            bool useThreads = true,
                                            bool index = false,
                                            IHost host = null)
{
    // infer the schema by reading the head of the first file into a dataframe
    var df = ReadCsv(filenames[0], sep: sep, header: header, names: names, dtypes: dtypes,
                     nrows: guess_rows, guess_rows: guess_rows, encoding: encoding, index: index);
    var sch = df.Schema;
    var cols = new TextLoader.Column[sch.Count];
    for (int i = 0; i < cols.Length; ++i)
    {
        cols[i] = TextLoader.Column.Parse(df.NameType(i));
        if (cols[i] == null)
        {
            throw Contracts.Except("Unable to parse '{0}'.", df.NameType(i));
        }
    }
    var args = new TextLoader.Options()
    {
        AllowQuoting = false,
        Separators = new[] { sep },
        Columns = cols,
        TrimWhitespace = true,
        UseThreads = useThreads,
        HasHeader = header,
        MaxRows = nrows > 0 ? (int?)nrows : null
    };
    if (host == null)
    {
        host = new ConsoleEnvironment().Register("TextLoader");
    }
    var multiSource = new MultiFileSource(filenames);
    return new TextLoader(host, args, multiSource).Load(multiSource);
}
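A minimal call sketch, assuming a semicolon-separated file named data.csv with a header row (the file name and separator are illustrative):

// Hypothetical file; the schema is guessed from the first 10 rows.
var view = ReadCsvToTextLoader(new[] { "data.csv" }, sep: ';', header: true);
Console.WriteLine(view.Schema.Count);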
/// <summary>
/// Converts a <see cref="TextLoader.Column"/> description into the corresponding <see cref="ColumnType"/> (older API).
/// </summary>
public static ColumnType Convert(TextLoader.Column col, IChannel ch = null)
{
    if (!col.Type.HasValue)
    {
        throw ch != null ? ch.Except("Kind is null") : Contracts.Except("Kind is null");
    }
    if (col.Source != null && col.Source.Length > 0)
    {
        if (col.Source.Length != 1)
        {
            throw Contracts.ExceptNotImpl("Convert of TextLoader.Column is not implemented for more than one range.");
        }
        if (col.Source[0].ForceVector)
        {
            if (!col.Source[0].Max.HasValue)
            {
                throw ch != null ? ch.Except("A vector column needs a dimension")
                                 : Contracts.Except("A vector column needs a dimension");
            }
            int delta = col.Source[0].Max.Value - col.Source[0].Min + 1;
            var colType = DataKind2ColumnType(col.Type.Value, ch);
            return new VectorType(colType.AsPrimitive(), delta);
        }
    }
    if (col.KeyRange != null)
    {
        var r = col.KeyRange;
        return new KeyType(col.Type.HasValue ? col.Type.Value.ToType() : null,
                           r.Min,
                           r.Max.HasValue ? (int)(r.Max.Value - r.Min + 1) : 0,
                           r.Contiguous);
    }
    return DataKind2ColumnType(col.Type.Value, ch);
}
public static ReturnResult<dynamic> PredictModel(IDatabase db, PredictionInput input)
{
    var results = new ReturnResult<dynamic>();
    var fileName = Guid.NewGuid().ToString();
    try
    {
        var model = GetModelById(db, input.ModelId);
        var modelInfo = ModelField.GetFieldsByModelId(db, model.Item.ModelId);
        var modelFile = FileStore.GetById(db, model.Item.FileStoreId.Value);

        MLContext mlContext = new MLContext();
        DataViewSchema predictionPipelineSchema;
        IDataView loadedData = null;

        // build the loader schema from the model's stored field definitions
        var columnData = new List<TextLoader.Column>();
        var curItem = 0;
        foreach (var c in modelInfo.Item)
        {
            var newColData = new TextLoader.Column()
            {
                DataKind = (DataKind)c.DataTypeId,
                Name = c.Name,
                Source = new TextLoader.Range[] { new TextLoader.Range(curItem) }
            };
            columnData.Add(newColData);
            curItem++;
        }

        File.WriteAllText(Path.GetTempPath() + fileName, input.CsvData);
        loadedData = mlContext.Data.LoadFromTextFile(
            Path.GetTempPath() + fileName,
            columnData.ToArray(),
            separatorChar: ',',
            hasHeader: false,
            allowQuoting: true
        );

        var outputColumn = modelInfo.Item.Single(x => x.IsOutput == true);
        using (MemoryStream ms = new MemoryStream(modelFile.Item.Data))
        {
            var predictionPipeline = mlContext.Model.Load(ms, out predictionPipelineSchema);
            IDataView predictions = predictionPipeline.Transform(loadedData);

            if ((DataKind)outputColumn.DataTypeId == DataKind.Single)
            {
                // regression pipelines expose "Score"; classifiers expose "PredictedLabel"
                var hasLabel = predictions.Schema.GetColumnOrNull("PredictedLabel");
                Single[] predictionOut = null;
                if (hasLabel == null)
                {
                    predictionOut = predictions.GetColumn<Single>("Score").ToArray();
                }
                else
                {
                    predictionOut = predictions.GetColumn<Single>("PredictedLabel").ToArray();
                }
                results.Item = predictionOut[0];
            }
            else if ((DataKind)outputColumn.DataTypeId == DataKind.String)
            {
                var predictionOut = predictions.GetColumn<String>("PredictedLabel").ToArray();
                results.Item = predictionOut[0];
            }
            else if ((DataKind)outputColumn.DataTypeId == DataKind.Boolean)
            {
                var predictionOut = predictions.GetColumn<Boolean>("PredictedLabel").ToArray();
                results.Item = predictionOut[0];
            }
        }
        results.Success = true;
    }
    catch (Exception)
    {
        results.Success = false;
    }

    // delete the prediction file (best effort)
    try
    {
        if (File.Exists(Path.GetTempPath() + fileName))
        {
            File.Delete(Path.GetTempPath() + fileName);
        }
    }
    catch (Exception) { }

    return results;
}
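The PredictionInput contract is not shown; a plausible minimal shape, inferred from the two reads in the method above, would be:

// Hypothetical DTO, inferred from usage above.
public class PredictionInput
{
    public int ModelId { get; set; }     // which stored model to load
    public string CsvData { get; set; }  // unlabeled rows, comma-separated
}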
/// <summary>
/// The main program entry point.
/// </summary>
/// <param name="args">The command line arguments.</param>
static void Main(string[] args)
{
    // create a machine learning context
    var context = new MLContext();

    // load data
    Console.WriteLine("Loading data....");
    var columnDef = new TextLoader.Column[]
    {
        new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
        new TextLoader.Column("Number", DataKind.Single, 0)
    };
    var trainDataView = context.Data.LoadFromTextFile(
        path: trainDataPath,
        columns: columnDef,
        hasHeader: true,
        separatorChar: ',');
    var testDataView = context.Data.LoadFromTextFile(
        path: testDataPath,
        columns: columnDef,
        hasHeader: true,
        separatorChar: ',');

    // build a training pipeline
    // step 1: map the number column to a key value and store in the label column
    var pipeline = context.Transforms.Conversion.MapValueToKey(
            outputColumnName: "Label",
            inputColumnName: "Number",
            keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)

        // step 2: concatenate all feature columns
        .Append(context.Transforms.Concatenate(
            "Features",
            nameof(Digit.PixelValues)))

        // step 3: cache data to speed up training
        .AppendCacheCheckpoint(context)

        // step 4: train the model with SDCA
        .Append(context.MulticlassClassification.Trainers.SdcaMaximumEntropy(
            labelColumnName: "Label",
            featureColumnName: "Features"))

        // step 5: map the label key value back to a number
        .Append(context.Transforms.Conversion.MapKeyToValue(
            outputColumnName: "Number",
            inputColumnName: "Label"));

    // train the model
    Console.WriteLine("Training model....");
    var model = pipeline.Fit(trainDataView);

    // use the model to make predictions on the test data
    Console.WriteLine("Evaluating model....");
    var predictions = model.Transform(testDataView);

    // evaluate the predictions
    var metrics = context.MulticlassClassification.Evaluate(
        data: predictions,
        labelColumnName: "Number",
        scoreColumnName: "Score");

    // show evaluation metrics
    Console.WriteLine($"Evaluation metrics");
    Console.WriteLine($"    MicroAccuracy:    {metrics.MicroAccuracy:0.###}");
    Console.WriteLine($"    MacroAccuracy:    {metrics.MacroAccuracy:0.###}");
    Console.WriteLine($"    LogLoss:          {metrics.LogLoss:#.###}");
    Console.WriteLine($"    LogLossReduction: {metrics.LogLossReduction:#.###}");
    Console.WriteLine();

    // grab five digits from the test data
    var digits = context.Data.CreateEnumerable<Digit>(testDataView, reuseRowObject: false).ToArray();
    var testDigits = new Digit[] { digits[5], digits[16], digits[28], digits[63], digits[129] };

    // create a prediction engine
    var engine = context.Model.CreatePredictionEngine<Digit, DigitPrediction>(model);

    // set up a table to show the predictions
    var table = new Table(TableConfiguration.Unicode());
    table.AddColumn("Digit");
    for (var i = 0; i < 10; i++)
    {
        table.AddColumn($"P{i}");
    }

    // predict each test digit
    for (var i = 0; i < testDigits.Length; i++)
    {
        var prediction = engine.Predict(testDigits[i]);
        table.AddRow(
            testDigits[i].Number,
            prediction.Score[0].ToString("P2"),
            prediction.Score[1].ToString("P2"),
            prediction.Score[2].ToString("P2"),
            prediction.Score[3].ToString("P2"),
            prediction.Score[4].ToString("P2"),
            prediction.Score[5].ToString("P2"),
            prediction.Score[6].ToString("P2"),
            prediction.Score[7].ToString("P2"),
            prediction.Score[8].ToString("P2"),
            prediction.Score[9].ToString("P2"));
    }

    // show results
    Console.WriteLine(table.ToString());
    Console.ReadKey();
}
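The MNIST examples above and below reference Digit and DigitPrediction classes that are not shown. A plausible minimal shape, inferred from the loader columns (784 pixel values plus a label) and assuming Microsoft.ML.Data is in scope for the VectorType attribute, would be:

// Hypothetical input/output classes, inferred from the column definitions above.
class Digit
{
    [VectorType(784)]
    public float[] PixelValues { get; set; }   // 28x28 pixel intensities
    public float Number { get; set; }          // the digit label, 0..9

    // helpers assumed by the CNTK example further down: raw features
    // and a one-hot encoded label vector
    public float[] GetFeatures() => PixelValues;
    public float[] GetLabel()
    {
        var label = new float[10];
        label[(int)Number] = 1;
        return label;
    }
}

class DigitPrediction
{
    public float[] Score { get; set; }         // one probability per class
}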
// This method is called if only a data file is specified, without a loader/term and value columns.
// It determines the type of the Value column and returns the appropriate TextLoader component factory.
private static IComponentFactory<IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
{
    Contracts.AssertValue(host);

    // If the user specified non-key values, we define the value column to be numeric.
    if (!keyValues)
    {
        return ComponentFactoryUtils.CreateFromFunction<IMultiStreamSource, IDataLoader>(
            (env, files) => TextLoader.Create(
                env,
                new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                files));
    }

    // If the user specified key values, we scan the values to determine the range of the key type.
    ulong min = ulong.MaxValue;
    ulong max = ulong.MinValue;
    try
    {
        var txtArgs = new TextLoader.Arguments();
        bool parsed = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
        host.Assert(parsed);
        var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
        using (var cursor = data.GetRowCursor(c => true))
        {
            var getTerm = cursor.GetGetter<DvText>(0);
            var getVal = cursor.GetGetter<DvText>(1);
            DvText txt = default(DvText);
            using (var ch = host.Start("Creating Text Lookup Loader"))
            {
                long countNonKeys = 0;
                while (cursor.MoveNext())
                {
                    getVal(ref txt);
                    ulong res;
                    // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds
                    // and res > 0, we update max and min accordingly. If res == 0 it means the value is
                    // missing, in which case we ignore it when computing max and min.
                    if (Conversions.Instance.TryParseKey(ref txt, 1, ulong.MaxValue, out res))
                    {
                        if (res < min && res != 0)
                        {
                            min = res;
                        }
                        if (res > max)
                        {
                            max = res;
                        }
                    }
                    // If parsing as a key did not succeed, the value can still be 0, so we try parsing it
                    // as a ulong. If it succeeds, then the value is 0, and we update min accordingly.
                    else if (Conversions.Instance.TryParse(ref txt, out res))
                    {
                        ch.Assert(res == 0);
                        min = 0;
                    }
                    // If parsing as a ulong fails, we increment the counter for the non-key values.
                    else
                    {
                        var term = default(DvText);
                        getTerm(ref term);
                        if (countNonKeys < 5)
                        {
                            ch.Warning("Term '{0}' in mapping file is mapped to non-key value '{1}'", term, txt);
                        }
                        countNonKeys++;
                    }
                }
                if (countNonKeys > 0)
                {
                    ch.Warning("Found {0} non-key values in the file '{1}'", countNonKeys, filename);
                }
                if (min > max)
                {
                    min = 0;
                    max = uint.MaxValue - 1;
                    ch.Warning("Did not find any valid key values in the file '{0}'", filename);
                }
                else
                {
                    ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename);
                }
                ch.Done();
            }
        }
    }
    catch (Exception e)
    {
        throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename);
    }

    // Choose the narrowest value column type that covers the observed key range.
    TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1);
    if (max - min < (ulong)int.MaxValue)
    {
        valueColumn.KeyRange = new KeyRange(min, max);
    }
    else if (max - min < (ulong)uint.MaxValue)
    {
        valueColumn.KeyRange = new KeyRange(min);
    }
    else
    {
        valueColumn.Type = DataKind.U8;
        valueColumn.KeyRange = new KeyRange(min);
    }

    return ComponentFactoryUtils.CreateFromFunction<IMultiStreamSource, IDataLoader>(
        (env, files) => TextLoader.Create(
            env,
            new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("Term", DataKind.TX, 0),
                    valueColumn
                }
            },
            files));
}
static void Main(string[] args)
{
    var context = new MLContext();

    Console.WriteLine("Loading Data...");
    var colDef = new TextLoader.Column[]
    {
        new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
        new TextLoader.Column("Number", DataKind.Single, 0)
    };
    var trainDataView = context.Data.LoadFromTextFile(trainDataPath, colDef, hasHeader: true, separatorChar: ',');
    var testDataView = context.Data.LoadFromTextFile(testDataPath, colDef, hasHeader: true, separatorChar: ',');

    var pipeline = context.Transforms.Conversion.MapValueToKey(
            outputColumnName: "Label",
            inputColumnName: "Number",
            keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)
        .Append(context.Transforms.Concatenate("Features", nameof(Digit.PixelValues)))
        .AppendCacheCheckpoint(context)
        .Append(context.MulticlassClassification.Trainers.OneVersusAll(
            context.BinaryClassification.Trainers.FastForest(), "Label"))
        .Append(context.Transforms.Conversion.MapKeyToValue(outputColumnName: "Number", inputColumnName: "Label"));

    Console.WriteLine("Training the model...");
    var model = pipeline.Fit(trainDataView);

    Console.WriteLine("Evaluating model...");
    var predictions = model.Transform(testDataView);
    var metrics = context.MulticlassClassification.Evaluate(predictions, labelColumnName: "Number", scoreColumnName: "Score");

    // show evaluation metrics
    Console.WriteLine($"Evaluation metrics");
    Console.WriteLine($"    MicroAccuracy:    {metrics.MicroAccuracy:0.###}");
    Console.WriteLine($"    MacroAccuracy:    {metrics.MacroAccuracy:0.###}");
    Console.WriteLine($"    LogLoss:          {metrics.LogLoss:#.###}");
    Console.WriteLine($"    LogLossReduction: {metrics.LogLossReduction:#.###}");
    Console.WriteLine();

    var digits = context.Data.CreateEnumerable<Digit>(testDataView, false).ToArray();
    var testDigits = new Digit[]
    {
        digits[215], // 0
        digits[202], // 1
        digits[199], // 2
        digits[200], // 3
        digits[198], // 4
        digits[207], // 5
        digits[201], // 6
        digits[220], // 7
        digits[226], // 8
        digits[235]  // 9
    };

    var engine = context.Model.CreatePredictionEngine<Digit, DigitPrediction>(model);

    var table = new BetterConsoleTables.Table(TableConfiguration.Unicode());
    table.AddColumn("Digits");
    for (var i = 0; i < 10; i++)
    {
        table.AddColumn($"P{i}");
    }

    for (var i = 0; i < testDigits.Length; i++)
    {
        var prediction = engine.Predict(testDigits[i]);
        table.AddRow(
            testDigits[i].Number,
            prediction.Score[0].ToString("P2"),
            prediction.Score[1].ToString("P2"),
            prediction.Score[2].ToString("P2"),
            prediction.Score[3].ToString("P2"),
            prediction.Score[4].ToString("P2"),
            prediction.Score[5].ToString("P2"),
            prediction.Score[6].ToString("P2"),
            prediction.Score[7].ToString("P2"),
            prediction.Score[8].ToString("P2"),
            prediction.Score[9].ToString("P2"));
    }

    // show results
    Console.WriteLine(table.ToString());
}
/// <summary>
/// The main program entry point.
/// </summary>
/// <param name="args">The command line arguments.</param>
static void Main(string[] args)
{
    // create a machine learning context
    var context = new MLContext();

    // load data
    Console.WriteLine("Loading data....");
    var columnDef = new TextLoader.Column[]
    {
        new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
        new TextLoader.Column(nameof(Digit.Number), DataKind.Single, 0)
    };
    var trainDataView = context.Data.LoadFromTextFile(
        path: trainDataPath,
        columns: columnDef,
        hasHeader: true,
        separatorChar: ',');
    var testDataView = context.Data.LoadFromTextFile(
        path: testDataPath,
        columns: columnDef,
        hasHeader: true,
        separatorChar: ',');

    // load training and testing data
    var training = context.Data.CreateEnumerable<Digit>(trainDataView, reuseRowObject: false);
    var testing = context.Data.CreateEnumerable<Digit>(testDataView, reuseRowObject: false);

    // set up data arrays
    var training_data = training.Select(v => v.GetFeatures()).ToArray();
    var training_labels = training.Select(v => v.GetLabel()).ToArray();
    var testing_data = testing.Select(v => v.GetFeatures()).ToArray();
    var testing_labels = testing.Select(v => v.GetLabel()).ToArray();

    // build features and labels
    var features = NetUtil.Var(new int[] { 28, 28 }, DataType.Float);
    var labels = NetUtil.Var(new int[] { 10 }, DataType.Float);

    // build the network
    var network = features
        .Dense(512, CNTKLib.ReLU)
        .Dense(10, CNTKLib.Softmax)
        .ToNetwork();
    Console.WriteLine("Model architecture:");
    Console.WriteLine(network.ToSummary());

    // set up the loss function and the classification error function
    var lossFunc = CNTKLib.CrossEntropyWithSoftmax(network.Output, labels);
    var errorFunc = CNTKLib.ClassificationError(network.Output, labels);

    // set up a trainer that uses the RMSProp algorithm
    var learner = network.GetRMSPropLearner(
        learningRateSchedule: 0.99,
        gamma: 0.95,
        inc: 2.0,
        dec: 0.5,
        max: 2.0,
        min: 0.5
    );

    // set up a trainer and an evaluator
    var trainer = network.GetTrainer(learner, lossFunc, errorFunc);
    var evaluator = network.GetEvaluator(errorFunc);

    // train the model
    Console.WriteLine("Epoch\tTrain\tTrain\tTest");
    Console.WriteLine("\tLoss\tError\tError");
    Console.WriteLine("-----------------------------");
    var maxEpochs = 50;
    var batchSize = 128;
    var loss = new double[maxEpochs];
    var trainingError = new double[maxEpochs];
    var testingError = new double[maxEpochs];
    var batchCount = 0;
    for (int epoch = 0; epoch < maxEpochs; epoch++)
    {
        // train one epoch on batches
        loss[epoch] = 0.0;
        trainingError[epoch] = 0.0;
        batchCount = 0;
        training_data.Index().Shuffle().Batch(batchSize, (indices, begin, end) =>
        {
            // get the current batch
            var featureBatch = features.GetBatch(training_data, indices, begin, end);
            var labelBatch = labels.GetBatch(training_labels, indices, begin, end);

            // train the network on the batch
            var result = trainer.TrainBatch(
                new[] {
                    (features, featureBatch),
                    (labels, labelBatch)
                },
public static ReturnResult<Model> Run([HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = null)] HttpRequest req, ILogger log)
{
    var dataFilePath = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
    try
    {
        db.BeginTransaction();
        MLContext context = new MLContext();

        // read the training request from the request body
        TrainInput input = null;
        using (StreamReader reader = new StreamReader(req.Body))
        {
            input = JsonConvert.DeserializeObject<TrainInput>(reader.ReadToEnd());
        }
        File.WriteAllText(dataFilePath, input.Data);

        // build the loader schema from the submitted column definitions
        IDataView loadedData = null;
        var columnData = new List<TextLoader.Column>();
        foreach (var c in input.Columns)
        {
            // data type 1 means "ignore this column"
            if (c.Type != 1)
            {
                var newColData = new TextLoader.Column()
                {
                    DataKind = (DataKind)c.Type,
                    Name = c.ColumnName,
                    Source = new TextLoader.Range[] { new TextLoader.Range(c.ColumnIndex) }
                };
                columnData.Add(newColData);
            }
        }
        loadedData = context.Data.LoadFromTextFile(
            dataFilePath,
            columnData.ToArray(),
            separatorChar: input.Separator,
            hasHeader: input.HasHeaders,
            allowQuoting: true
        );
        loadedData = context.Data.ShuffleRows(loadedData);

        /*
         * Multiclass will be used in the case of binary experiments and multiclass experiments.
         * This is because multiclass can accept all types as an output column. This will
         * allow less interaction with the user and a better user experience.
         */
        double bestRunMetric = 0;
        ITransformer bestModel = null;
        if (input.ModelType == TrainInput.ModelTypes.Multiclass)
        {
            var settings = new MulticlassExperimentSettings() { MaxExperimentTimeInSeconds = 20 };
            var training = context.Auto().CreateMulticlassClassificationExperiment(settings);
            var results = training.Execute(loadedData, labelColumnName: input.LabelColumn);
            bestRunMetric = results.BestRun.ValidationMetrics.MacroAccuracy;
            bestModel = results.BestRun.Model;
        }
        else if (input.ModelType == TrainInput.ModelTypes.Binary)
        {
            var settings = new BinaryExperimentSettings() { MaxExperimentTimeInSeconds = 20 };
            var training = context.Auto().CreateBinaryClassificationExperiment(settings);
            var results = training.Execute(loadedData, labelColumnName: input.LabelColumn);
            bestRunMetric = results.BestRun.ValidationMetrics.Accuracy;
            bestModel = results.BestRun.Model;
        }
        else if (input.ModelType == TrainInput.ModelTypes.Regression)
        {
            var settings = new RegressionExperimentSettings() { MaxExperimentTimeInSeconds = 20 };
            var training = context.Auto().CreateRegressionExperiment(settings);
            var results = training.Execute(loadedData, labelColumnName: input.LabelColumn);
            // R-squared can be negative for models worse than the mean; clamp it to zero
            bestRunMetric = Math.Max(results.BestRun.ValidationMetrics.RSquared, 0);
            bestModel = results.BestRun.Model;
        }
        else
        {
            throw new Exception("Invalid model type");
        }

        // save the model to the database
        var modelFileId = 0;
        using (MemoryStream ms = new MemoryStream())
        {
            context.Model.Save(bestModel, loadedData.Schema, ms);
            FileStore modelSave = new FileStore() { Data = ms.ToArray() };
            modelFileId = FileStore.InsertUpdate(db, modelSave).Item.FileStoreId;
        }

        var resultModel = new Model()
        {
            FileStoreId = modelFileId,
            Accuracy = bestRunMetric,
            Rows = input.Data.Trim().Split('\n').Length
        };
        db.CompleteTransaction();
        return new ReturnResult<Model>() { Success = true, Item = resultModel };
    }
    catch (Exception e)
    {
        db.AbortTransaction();
        log.LogError(e.Message);
        return new ReturnResult<Model>() { Success = false, Exception = e };
    }
}
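The TrainInput contract is not shown; a plausible minimal shape, inferred from the fields the function reads (the type and member names below are assumptions matching the snippet's usage), would be:

// Hypothetical DTO, inferred from usage above.
public class TrainInput
{
    public enum ModelTypes { Multiclass, Binary, Regression }

    public string Data { get; set; }                       // raw CSV payload
    public List<ColumnDefinition> Columns { get; set; }
    public char Separator { get; set; }
    public bool HasHeaders { get; set; }
    public string LabelColumn { get; set; }
    public ModelTypes ModelType { get; set; }
}

// Hypothetical column descriptor; Type carries an ML.NET DataKind value, with 1 meaning "ignore".
public class ColumnDefinition
{
    public string ColumnName { get; set; }
    public int ColumnIndex { get; set; }
    public int Type { get; set; }
}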