Example #1
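This method converts a suggested pipeline into an entry-point graph. It requires every transform to have entry-point support, chains the transforms into a sub-graph, appends the learner if one is present, and uses ManyHeterogeneousModelCombiner to merge the transform models with the predictor model so a single model both featurizes and scores the data.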
            public AutoInference.EntryPointGraphDef ToEntryPointGraph(IHostEnvironment env)
            {
                // All transforms must have associated PipelineNode objects
                var unsupportedTransform = Transforms.Where(transform => transform.PipelineNode == null).Cast<TransformInference.SuggestedTransform?>().FirstOrDefault();

                if (unsupportedTransform != null)
                {
                    throw env.ExceptNotSupp($"All transforms in recipe must have entrypoint support. {unsupportedTransform} is not yet supported.");
                }
                var subGraph = env.CreateExperiment();

                Var<IDataView> lastOutput = new Var<IDataView>();

                // Chain transforms
                var transformsModels = new List<Var<ITransformModel>>();

                foreach (var transform in Transforms)
                {
                    transform.PipelineNode.SetInputData(lastOutput);
                    var transformAddResult = transform.PipelineNode.Add(subGraph);
                    transformsModels.Add(transformAddResult.Model);
                    lastOutput = transformAddResult.OutData;
                }

                // Add learner, if present. If not, just return transforms graph object.
                if (Learners.Length > 0 && Learners[0].PipelineNode != null)
                {
                    // Add learner
                    var learner = Learners[0];
                    learner.PipelineNode.SetInputData(lastOutput);
                    var learnerAddResult = learner.PipelineNode.Add(subGraph);

                    // Create single model for featurizing and scoring data,
                    // if transforms present.
                    if (Transforms.Length > 0)
                    {
                        var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                        {
                            TransformModels = new ArrayVar<ITransformModel>(transformsModels.ToArray()),
                            PredictorModel  = learnerAddResult.Model
                        };
                        var modelCombineOutput = subGraph.Add(modelCombine);

                        return new AutoInference.EntryPointGraphDef(subGraph, modelCombineOutput.PredictorModel, lastOutput);
                    }

                    // No transforms present, so just return predictor's model.
                    return new AutoInference.EntryPointGraphDef(subGraph, learnerAddResult.Model, lastOutput);
                }

                return new AutoInference.EntryPointGraphDef(subGraph, null, lastOutput);
            }
Example #2
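A variant of the same conversion: transforms without a PipelineNode are skipped, the remaining transforms are chained, the learner is appended, and ManyHeterogeneousModelCombiner merges the transform models with the predictor model whenever any transforms are present.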
        /// <summary>
        /// Constructs an entrypoint graph from the current pipeline.
        /// </summary>
        public AutoInference.EntryPointGraphDef ToEntryPointGraph(Experiment experiment = null)
        {
            _env.CheckValue(Learner.PipelineNode, nameof(Learner.PipelineNode));
            var subGraph = experiment ?? _env.CreateExperiment();

            // Insert first node
            Var<IDataView> lastOutput = new Var<IDataView>();

            // Chain transforms
            var transformsModels = new List<Var<ITransformModel>>();
            var viableTransforms = Transforms.ToList().Where(transform => transform.PipelineNode != null);

            foreach (var transform in viableTransforms)
            {
                transform.PipelineNode.SetInputData(lastOutput);
                var returnedDataAndModel1 = transform.PipelineNode.Add(subGraph);
                transformsModels.Add(returnedDataAndModel1.Model);
                lastOutput = returnedDataAndModel1.OutData;
            }

            // Add learner
            Learner.PipelineNode?.SetInputData(lastOutput);
            var returnedDataAndModel2 = Learner.PipelineNode?.Add(subGraph);

            // Create single model for featurizing and scoring data,
            // if transforms present.
            if (Transforms.Length > 0)
            {
                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(transformsModels.ToArray()),
                    PredictorModel  = returnedDataAndModel2?.Model
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                return new AutoInference.EntryPointGraphDef(subGraph, modelCombineOutput.PredictorModel, lastOutput);
            }

            // No transforms present, so just return predictor's model.
            return new AutoInference.EntryPointGraphDef(subGraph, returnedDataAndModel2?.Model, lastOutput);
        }
Example #3
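This test exercises the cross-validation macro with non-default column names: the sub-graph maps "Label" to a key column "Label1", hashes "Workclass" into "GroupId1", trains a FastTree ranker on those columns, and combines the transform and predictor models. The test then runs Models.CrossValidator as a ranking task and checks the NDCG average, standard deviation, and per-fold values.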
        public void TestCrossValidationMacroWithNonDefaultNames()
        {
            string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");

            using (var env = new TlcEnvironment(42))
            {
                var subGraph = env.CreateExperiment();

                var textToKey = new ML.Transforms.TextToKeyConverter();
                textToKey.Column = new[] { new ML.Transforms.TermTransformColumn() { Name = "Label1", Source = "Label" } };
                var textToKeyOutput = subGraph.Add(textToKey);

                var hash = new ML.Transforms.HashConverter();
                hash.Column = new[] { new ML.Transforms.HashJoinTransformColumn() { Name = "GroupId1", Source = "Workclass" } };
                hash.Data = textToKeyOutput.OutputData;
                var hashOutput = subGraph.Add(hash);

                var learnerInput = new Trainers.FastTreeRanker
                {
                    TrainingData  = hashOutput.OutputData,
                    NumThreads    = 1,
                    LabelColumn   = "Label1",
                    GroupIdColumn = "GroupId1"
                };
                var learnerOutput = subGraph.Add(learnerInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(textToKeyOutput.Model, hashOutput.Model),
                    PredictorModel  = learnerOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment  = env.CreateExperiment();
                var importInput = new ML.Data.TextLoader(dataPath);
                importInput.Arguments.HasHeader = true;
                importInput.Arguments.Column    = new TextLoaderColumn[]
                {
                    new TextLoaderColumn {
                        Name = "Label", Source = new[] { new TextLoaderRange(0) }
                    },
                    new TextLoaderColumn {
                        Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = ML.Data.DataKind.Text
                    },
                    new TextLoaderColumn {
                        Name = "Features", Source = new[] { new TextLoaderRange(9, 14) }
                    }
                };
                var importOutput = experiment.Add(importInput);

                var crossValidate = new Models.CrossValidator
                {
                    Data           = importOutput.Data,
                    Nodes          = subGraph,
                    TransformModel = null,
                    LabelColumn    = "Label1",
                    GroupColumn    = "GroupId1",
                    Kind           = Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
                };
                crossValidate.Inputs.Data            = textToKey.Data;
                crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
                var crossValidateOutput = experiment.Add(crossValidate);
                experiment.Compile();
                experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
                experiment.Run();
                var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("NDCG", out int metricCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
                Assert.True(b);
                using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
                {
                    var    getter     = cursor.GetGetter<VBuffer<double>>(metricCol);
                    var    foldGetter = cursor.GetGetter<DvText>(foldCol);
                    DvText fold       = default;

                    // Get the average.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    var avg = default(VBuffer<double>);
                    getter(ref avg);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Average"));

                    // Get the standard deviation.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    var stdev = default(VBuffer<double>);
                    getter(ref stdev);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Standard Deviation"));
                    Assert.Equal(5.247, stdev.Values[0], 3);
                    Assert.Equal(4.703, stdev.Values[1], 3);
                    Assert.Equal(3.844, stdev.Values[2], 3);

                    var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
                    sumBldr.Reset(avg.Length, true);
                    var val = default(VBuffer<double>);
                    for (int f = 0; f < 2; f++)
                    {
                        b = cursor.MoveNext();
                        Assert.True(b);
                        getter(ref val);
                        foldGetter(ref fold);
                        sumBldr.AddFeatures(0, ref val);
                        Assert.True(fold.EqualsStr("Fold " + f));
                    }
                    var sum = default(VBuffer<double>);
                    sumBldr.GetResult(ref sum);
                    for (int i = 0; i < avg.Length; i++)
                    {
                        Assert.Equal(avg.Values[i], sum.Values[i] / 2);
                    }
                    b = cursor.MoveNext();
                    Assert.False(b);
                }
            }
        }
Example #4
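This test exercises the cross-validation macro with a stratification column: the sub-graph is a no-op transform followed by an SDCA binary classifier, combined into a single model, and the macro is run with StratificationColumn = "Strat". The test verifies the AUC average, standard deviation, and per-fold values.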
        public void TestCrossValidationMacroWithStratification()
        {
            var dataPath = GetDataPath(@"breast-cancer.txt");

            using (var env = new TlcEnvironment(42))
            {
                var subGraph = env.CreateExperiment();

                var nop       = new ML.Transforms.NoOperation();
                var nopOutput = subGraph.Add(nop);

                var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier
                {
                    TrainingData = nopOutput.OutputData,
                    NumThreads   = 1
                };
                var learnerOutput = subGraph.Add(learnerInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(nopOutput.Model),
                    PredictorModel  = learnerOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment  = env.CreateExperiment();
                var importInput = new ML.Data.TextLoader(dataPath);
                importInput.Arguments.Column = new ML.Data.TextLoaderColumn[]
                {
                    new ML.Data.TextLoaderColumn {
                        Name = "Label", Source = new[] { new ML.Data.TextLoaderRange(0) }
                    },
                    new ML.Data.TextLoaderColumn {
                        Name = "Strat", Source = new[] { new ML.Data.TextLoaderRange(1) }
                    },
                    new ML.Data.TextLoaderColumn {
                        Name = "Features", Source = new[] { new ML.Data.TextLoaderRange(2, 9) }
                    }
                };
                var importOutput = experiment.Add(importInput);

                var crossValidate = new ML.Models.CrossValidator
                {
                    Data                 = importOutput.Data,
                    Nodes                = subGraph,
                    TransformModel       = null,
                    StratificationColumn = "Strat"
                };
                crossValidate.Inputs.Data            = nop.Data;
                crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
                var crossValidateOutput = experiment.Add(crossValidate);
                experiment.Compile();
                experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
                experiment.Run();
                var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("AUC", out int metricCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
                Assert.True(b);
                using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
                {
                    var    getter     = cursor.GetGetter<double>(metricCol);
                    var    foldGetter = cursor.GetGetter<DvText>(foldCol);
                    DvText fold       = default;

                    // Get the average.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double avg = 0;
                    getter(ref avg);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Average"));

                    // Get the standard deviation.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double stdev = 0;
                    getter(ref stdev);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Standard Deviation"));
                    Assert.Equal(0.00485, stdev, 5);

                    double sum = 0;
                    double val = 0;
                    for (int f = 0; f < 2; f++)
                    {
                        b = cursor.MoveNext();
                        Assert.True(b);
                        getter(ref val);
                        foldGetter(ref fold);
                        sum += val;
                        Assert.True(fold.EqualsStr("Fold " + f));
                    }
                    Assert.Equal(avg, sum / 2);
                    b = cursor.MoveNext();
                    Assert.False(b);
                }
            }
        }
Example #5
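This test exercises the cross-validation macro for multi-class classification: the combined no-op/SDCA model is cross-validated, and the test checks the micro-averaged accuracy per fold, the per-fold confusion matrices, and that the warnings output is empty.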
        public void TestCrossValidationMacroWithMultiClass()
        {
            var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");

            using (var env = new TlcEnvironment(42))
            {
                var subGraph = env.CreateExperiment();

                var nop       = new ML.Transforms.NoOperation();
                var nopOutput = subGraph.Add(nop);

                var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentClassifier
                {
                    TrainingData = nopOutput.OutputData,
                    NumThreads   = 1
                };
                var learnerOutput = subGraph.Add(learnerInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(nopOutput.Model),
                    PredictorModel  = learnerOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment   = env.CreateExperiment();
                var importInput  = new ML.Data.TextLoader(dataPath);
                var importOutput = experiment.Add(importInput);

                var crossValidate = new ML.Models.CrossValidator
                {
                    Data           = importOutput.Data,
                    Nodes          = subGraph,
                    Kind           = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
                    TransformModel = null
                };
                crossValidate.Inputs.Data            = nop.Data;
                crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
                var crossValidateOutput = experiment.Add(crossValidate);

                experiment.Compile();
                importInput.SetInput(env, experiment);
                experiment.Run();
                var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
                Assert.True(b);
                using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
                {
                    var    getter     = cursor.GetGetter<double>(metricCol);
                    var    foldGetter = cursor.GetGetter<DvText>(foldCol);
                    DvText fold       = default;

                    // Get the average.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double avg = 0;
                    getter(ref avg);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Average"));

                    // Get the standard deviation.
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double stdev = 0;
                    getter(ref stdev);
                    foldGetter(ref fold);
                    Assert.True(fold.EqualsStr("Standard Deviation"));
                    Assert.Equal(0.025, stdev, 3);

                    double sum = 0;
                    double val = 0;
                    for (int f = 0; f < 2; f++)
                    {
                        b = cursor.MoveNext();
                        Assert.True(b);
                        getter(ref val);
                        foldGetter(ref fold);
                        sum += val;
                        Assert.True(fold.EqualsStr("Fold " + f));
                    }
                    Assert.Equal(avg, sum / 2);
                    b = cursor.MoveNext();
                    Assert.False(b);
                }

                var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix);
                schema = confusion.Schema;
                b      = schema.TryGetColumnIndex("Count", out int countCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("Fold Index", out foldCol);
                Assert.True(b);
                var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol);
                Assert.True(type != null && type.ItemType.IsText && type.VectorSize == 10);
                var slotNames = default(VBuffer<DvText>);
                schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames);
                Assert.True(slotNames.Values.Select((s, i) => s.EqualsStr(i.ToString())).All(x => x));
                using (var curs = confusion.GetRowCursor(col => true))
                {
                    var countGetter = curs.GetGetter<VBuffer<double>>(countCol);
                    var foldGetter  = curs.GetGetter<DvText>(foldCol);
                    var confCount   = default(VBuffer<double>);
                    var foldIndex   = default(DvText);
                    int rowCount    = 0;
                    var foldCur     = "Fold 0";
                    while (curs.MoveNext())
                    {
                        countGetter(ref confCount);
                        foldGetter(ref foldIndex);
                        rowCount++;
                        Assert.True(foldIndex.EqualsStr(foldCur));
                        if (rowCount == 10)
                        {
                            rowCount = 0;
                            foldCur  = "Fold 1";
                        }
                    }
                    Assert.Equal(0, rowCount);
                }

                var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
                using (var cursor = warnings.GetRowCursor(col => true))
                    Assert.False(cursor.MoveNext());
            }
        }
Example #6
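This test exercises the cross-validation macro for regression with a weight column: a random-number transform generates "Weight1", a Poisson regressor trains with it as the weight column, and the combined model is cross-validated. The test verifies both the weighted and unweighted "L1(avg)" averages, standard deviations, and per-fold values.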
        public void TestCrossValidationMacro()
        {
            var dataPath = GetDataPath(TestDatasets.winequality.trainFilename);

            using (var env = new TlcEnvironment(42))
            {
                var subGraph = env.CreateExperiment();

                var nop       = new ML.Transforms.NoOperation();
                var nopOutput = subGraph.Add(nop);

                var generate = new ML.Transforms.RandomNumberGenerator();
                generate.Column = new[] { new ML.Transforms.GenerateNumberTransformColumn() { Name = "Weight1" } };
                generate.Data = nopOutput.OutputData;
                var generateOutput = subGraph.Add(generate);

                var learnerInput = new ML.Trainers.PoissonRegressor
                {
                    TrainingData = generateOutput.OutputData,
                    NumThreads   = 1,
                    WeightColumn = "Weight1"
                };
                var learnerOutput = subGraph.Add(learnerInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(nopOutput.Model, generateOutput.Model),
                    PredictorModel  = learnerOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment  = env.CreateExperiment();
                var importInput = new ML.Data.TextLoader(dataPath)
                {
                    Arguments = new TextLoaderArguments
                    {
                        Separator = new[] { ';' },
                        HasHeader = true,
                        Column    = new[]
                        {
                            new TextLoaderColumn()
                            {
                                Name   = "Label",
                                Source = new [] { new TextLoaderRange(11) },
                                Type   = ML.Data.DataKind.Num
                            },

                            new TextLoaderColumn()
                            {
                                Name   = "Features",
                                Source = new [] { new TextLoaderRange(0, 10) },
                                Type   = ML.Data.DataKind.Num
                            }
                        }
                    }
                };
                var importOutput = experiment.Add(importInput);

                var crossValidate = new ML.Models.CrossValidator
                {
                    Data           = importOutput.Data,
                    Nodes          = subGraph,
                    Kind           = ML.Models.MacroUtilsTrainerKinds.SignatureRegressorTrainer,
                    TransformModel = null,
                    WeightColumn   = "Weight1"
                };
                crossValidate.Inputs.Data            = nop.Data;
                crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
                var crossValidateOutput = experiment.Add(crossValidate);

                experiment.Compile();
                importInput.SetInput(env, experiment);
                experiment.Run();
                var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("L1(avg)", out int metricCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
                Assert.True(b);
                b = schema.TryGetColumnIndex("IsWeighted", out int isWeightedCol);
                using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol || col == isWeightedCol))
                {
                    var    getter           = cursor.GetGetter<double>(metricCol);
                    var    foldGetter       = cursor.GetGetter<DvText>(foldCol);
                    var    isWeightedGetter = cursor.GetGetter<DvBool>(isWeightedCol);
                    DvText fold             = default;
                    DvBool isWeighted       = default;

                    double avg         = 0;
                    double weightedAvg = 0;
                    for (int w = 0; w < 2; w++)
                    {
                        // Get the average.
                        b = cursor.MoveNext();
                        Assert.True(b);
                        if (w == 1)
                        {
                            getter(ref weightedAvg);
                        }
                        else
                        {
                            getter(ref avg);
                        }
                        foldGetter(ref fold);
                        Assert.True(fold.EqualsStr("Average"));
                        isWeightedGetter(ref isWeighted);
                        Assert.True(isWeighted.IsTrue == (w == 1));

                        // Get the standard deviation.
                        b = cursor.MoveNext();
                        Assert.True(b);
                        double stdev = 0;
                        getter(ref stdev);
                        foldGetter(ref fold);
                        Assert.True(fold.EqualsStr("Standard Deviation"));
                        if (w == 1)
                        {
                            Assert.Equal(0.002827, stdev, 6);
                        }
                        else
                        {
                            Assert.Equal(0.002376, stdev, 6);
                        }
                        isWeightedGetter(ref isWeighted);
                        Assert.True(isWeighted.IsTrue == (w == 1));
                    }
                    double sum         = 0;
                    double weightedSum = 0;
                    for (int f = 0; f < 2; f++)
                    {
                        for (int w = 0; w < 2; w++)
                        {
                            b = cursor.MoveNext();
                            Assert.True(b);
                            double val = 0;
                            getter(ref val);
                            foldGetter(ref fold);
                            if (w == 1)
                            {
                                weightedSum += val;
                            }
                            else
                            {
                                sum += val;
                            }
                            Assert.True(fold.EqualsStr("Fold " + f));
                            isWeightedGetter(ref isWeighted);
                            Assert.True(isWeighted.IsTrue == (w == 1));
                        }
                    }
                    Assert.Equal(weightedAvg, weightedSum / 2);
                    Assert.Equal(avg, sum / 2);
                    b = cursor.MoveNext();
                    Assert.False(b);
                }
            }
        }
Example #7
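This test exercises the binary cross-validation macro: categorical encoding and column concatenation feed a logistic regression binary classifier, the models are combined, and the test checks the AUC reported by BinaryCrossValidator.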
        public void TestCrossValidationBinaryMacro()
        {
            var dataPath = GetDataPath("adult.tiny.with-schema.txt");

            using (var env = new TlcEnvironment())
            {
                var subGraph = env.CreateExperiment();

                var catInput = new ML.Transforms.CategoricalOneHotVectorizer();
                catInput.AddColumn("Categories");
                var catOutput = subGraph.Add(catInput);

                var concatInput = new ML.Transforms.ColumnConcatenator
                {
                    Data = catOutput.OutputData
                };
                concatInput.AddColumn("Features", "Categories", "NumericFeatures");
                var concatOutput = subGraph.Add(concatInput);

                var lrInput = new ML.Trainers.LogisticRegressionBinaryClassifier
                {
                    TrainingData = concatOutput.OutputData,
                    NumThreads   = 1
                };
                var lrOutput = subGraph.Add(lrInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(catOutput.Model, concatOutput.Model),
                    PredictorModel  = lrOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment = env.CreateExperiment();

                var importInput  = new ML.Data.TextLoader(dataPath);
                var importOutput = experiment.Add(importInput);

                var crossValidateBinary = new ML.Models.BinaryCrossValidator
                {
                    Data  = importOutput.Data,
                    Nodes = subGraph
                };
                crossValidateBinary.Inputs.Data   = catInput.Data;
                crossValidateBinary.Outputs.Model = modelCombineOutput.PredictorModel;
                var crossValidateOutput = experiment.Add(crossValidateBinary);

                experiment.Compile();
                importInput.SetInput(env, experiment);
                experiment.Run();
                var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("AUC", out int aucCol);
                Assert.True(b);
                using (var cursor = data.GetRowCursor(col => col == aucCol))
                {
                    var getter = cursor.GetGetter<double>(aucCol);
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double auc = 0;
                    getter(ref auc);
                    Assert.Equal(0.87, auc, 1);
                    b = cursor.MoveNext();
                    Assert.False(b);
                }
            }
        }
Example #8
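This test exercises the train-test macro: the same featurization feeds an SDCA binary classifier with a hinge loss, and the test checks the AUC reported by TrainTestBinaryEvaluator.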
        public void TestTrainTestMacro()
        {
            var dataPath = GetDataPath("adult.tiny.with-schema.txt");

            using (var env = new TlcEnvironment())
            {
                var subGraph = env.CreateExperiment();

                var catInput = new ML.Transforms.CategoricalOneHotVectorizer();
                catInput.AddColumn("Categories");
                var catOutput = subGraph.Add(catInput);

                var concatInput = new ML.Transforms.ColumnConcatenator
                {
                    Data = catOutput.OutputData
                };
                concatInput.AddColumn("Features", "Categories", "NumericFeatures");
                var concatOutput = subGraph.Add(concatInput);

                var sdcaInput = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier
                {
                    TrainingData = concatOutput.OutputData,
                    LossFunction = new HingeLossSDCAClassificationLossFunction()
                    {
                        Margin = 1.1f
                    },
                    NumThreads = 1,
                    Shuffle    = false
                };
                var sdcaOutput = subGraph.Add(sdcaInput);

                var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
                {
                    TransformModels = new ArrayVar<ITransformModel>(catOutput.Model, concatOutput.Model),
                    PredictorModel  = sdcaOutput.PredictorModel
                };
                var modelCombineOutput = subGraph.Add(modelCombine);

                var experiment = env.CreateExperiment();

                var importInput  = new ML.Data.TextLoader(dataPath);
                var importOutput = experiment.Add(importInput);

                var trainTestInput = new ML.Models.TrainTestBinaryEvaluator
                {
                    TrainingData = importOutput.Data,
                    TestingData  = importOutput.Data,
                    Nodes        = subGraph
                };
                trainTestInput.Inputs.Data   = catInput.Data;
                trainTestInput.Outputs.Model = modelCombineOutput.PredictorModel;
                var trainTestOutput = experiment.Add(trainTestInput);

                experiment.Compile();
                experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
                experiment.Run();
                var data = experiment.GetOutput(trainTestOutput.OverallMetrics);

                var schema = data.Schema;
                var b      = schema.TryGetColumnIndex("AUC", out int aucCol);
                Assert.True(b);
                using (var cursor = data.GetRowCursor(col => col == aucCol))
                {
                    var getter = cursor.GetGetter<double>(aucCol);
                    b = cursor.MoveNext();
                    Assert.True(b);
                    double auc = 0;
                    getter(ref auc);
                    Assert.Equal(0.93, auc, 2);
                    b = cursor.MoveNext();
                    Assert.False(b);
                }
            }
        }