Esempio n. 1
0
        /// <summary>
        /// Cross-validates a FastTree ranking sub-graph whose label and group-id columns are
        /// created under non-default names ("Label1", "GroupId1") and checks the macro output:
        /// the overall NDCG metrics contain an "Average" row, a "Standard Deviation" row and one
        /// row per fold (whose sum divided by the fold count matches the average), and the
        /// per-instance metrics expose the "Workclass" values through the "Instance" column.
        /// </summary>
        public void TestCrossValidationMacroWithNonDefaultNames()
        {
            // The CrossValidator below does not set a fold count; these assertions expect two
            // per-fold rows. NOTE(review): presumably the macro's default — confirm if it changes.
            const int numFolds = 2;
            // Decimal places used when comparing the reported average with the recomputed one,
            // instead of relying on exact floating-point equality.
            const int averagePrecision = 10;

            string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
            var    env      = new MLContext(42);
            var    subGraph = env.CreateExperiment();

            // Sub-graph: Label -> Label1 (text-to-key), Workclass -> GroupId1 (hash), then a
            // ranker trained on the renamed columns.
            var textToKey = new Legacy.Transforms.TextToKeyConverter();

            textToKey.Column = new[] { new Legacy.Transforms.ValueToKeyMappingTransformerColumn()
                                       {
                                           Name = "Label1", Source = "Label"
                                       } };
            var textToKeyOutput = subGraph.Add(textToKey);

            var hash = new Legacy.Transforms.HashConverter();

            hash.Column = new[] { new Legacy.Transforms.HashJoiningTransformColumn()
                                  {
                                      Name = "GroupId1", Source = "Workclass"
                                  } };
            hash.Data = textToKeyOutput.OutputData;
            var hashOutput = subGraph.Add(hash);

            var learnerInput = new Legacy.Trainers.FastTreeRanker
            {
                TrainingData  = hashOutput.OutputData,
                NumThreads    = 1,
                LabelColumn   = "Label1",
                GroupIdColumn = "GroupId1"
            };
            var learnerOutput = subGraph.Add(learnerInput);

            var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
            {
                TransformModels = new ArrayVar <TransformModel>(textToKeyOutput.Model, hashOutput.Model),
                PredictorModel  = learnerOutput.PredictorModel
            };
            var modelCombineOutput = subGraph.Add(modelCombine);

            // Outer graph: load the data and run the cross-validation macro over the sub-graph,
            // telling it the non-default label/group names and using Workclass as the name column.
            var experiment  = env.CreateExperiment();
            var importInput = new Legacy.Data.TextLoader(dataPath);

            importInput.Arguments.HasHeader = true;
            importInput.Arguments.Column    = new TextLoaderColumn[]
            {
                new TextLoaderColumn {
                    Name = "Label", Source = new[] { new TextLoaderRange(0) }
                },
                new TextLoaderColumn {
                    Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = Legacy.Data.DataKind.Text
                },
                new TextLoaderColumn {
                    Name = "Features", Source = new[] { new TextLoaderRange(9, 14) }
                }
            };
            var importOutput = experiment.Add(importInput);

            var crossValidate = new Legacy.Models.CrossValidator
            {
                Data           = importOutput.Data,
                Nodes          = subGraph,
                TransformModel = null,
                LabelColumn    = "Label1",
                GroupColumn    = "GroupId1",
                NameColumn     = "Workclass",
                Kind           = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
            };

            crossValidate.Inputs.Data            = textToKey.Data;
            crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
            var crossValidateOutput = experiment.Add(crossValidate);

            experiment.Compile();
            experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
            experiment.Run();

            var data   = experiment.GetOutput(crossValidateOutput.OverallMetrics);
            var schema = data.Schema;
            Assert.True(schema.TryGetColumnIndex("NDCG", out int metricCol));
            Assert.True(schema.TryGetColumnIndex("Fold Index", out int foldCol));
            using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
            {
                var getter                = cursor.GetGetter<VBuffer<double>>(metricCol);
                var foldGetter            = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
                ReadOnlyMemory<char> fold = default;

                // First row: the average across folds.
                Assert.True(cursor.MoveNext());
                var avg = default(VBuffer<double>);
                getter(ref avg);
                foldGetter(ref fold);
                Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));

                // Second row: the standard deviation across folds.
                Assert.True(cursor.MoveNext());
                var stdev = default(VBuffer<double>);
                getter(ref stdev);
                foldGetter(ref fold);
                Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
                var stdevValues = stdev.GetValues();
                Assert.Equal(2.462, stdevValues[0], 3);
                Assert.Equal(2.763, stdevValues[1], 3);
                Assert.Equal(3.273, stdevValues[2], 3);

                // Next rows: one per fold, labelled "Fold 0", "Fold 1", ...; accumulate them so
                // the reported average can be recomputed.
                var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
                sumBldr.Reset(avg.Length, true);
                var val = default(VBuffer<double>);
                for (int f = 0; f < numFolds; f++)
                {
                    Assert.True(cursor.MoveNext());
                    getter(ref val);
                    foldGetter(ref fold);
                    sumBldr.AddFeatures(0, in val);
                    Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
                }
                var sum = default(VBuffer<double>);
                sumBldr.GetResult(ref sum);

                var avgValues = avg.GetValues();
                var sumValues = sum.GetValues();
                for (int i = 0; i < avgValues.Length; i++)
                {
                    Assert.Equal(avgValues[i], sumValues[i] / numFolds, averagePrecision);
                }

                // No rows beyond the per-fold ones.
                Assert.False(cursor.MoveNext());
            }

            // The per-instance "Instance" column should carry Workclass values. The expected set
            // is loop-invariant, so build it once.
            var expectedInstanceNames = new HashSet<string> { "Private", "?", "Federal-gov" };
            data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
            Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
            using (var cursor = data.GetRowCursor(col => col == nameCol))
            {
                var getter                = cursor.GetGetter<ReadOnlyMemory<char>>(nameCol);
                ReadOnlyMemory<char> name = default;
                while (cursor.MoveNext())
                {
                    getter(ref name);
                    Assert.Contains(name.ToString(), expectedInstanceNames);
                    // Only the first few rows are inspected (stop once Position exceeds 4).
                    if (cursor.Position > 4)
                    {
                        break;
                    }
                }
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Cross-validates a multi-class logistic-regression sub-graph on Train-Tiny-28x28 after a
        /// row-range filter on "Label", stratifying on a text-to-key copy of the label ("Strat").
        /// Verifies that the macro emits exactly two warnings: one about test instances whose
        /// class was not seen in training, and one about variable-length score/class columns.
        /// </summary>
        public void TestCrossValidationMacroMultiClassWithWarnings()
        {
            var dataPath   = GetDataPath(@"Train-Tiny-28x28.txt");
            var env        = new MLContext(42);
            var innerGraph = env.CreateExperiment();

            // Inner graph: a pass-through transform feeding the trainer; the pass-through's
            // input becomes the sub-graph's data input below.
            var passThrough       = new Legacy.Transforms.NoOperation();
            var passThroughOutput = innerGraph.Add(passThrough);
            var trainerOutput = innerGraph.Add(new Legacy.Trainers.LogisticRegressionClassifier
            {
                TrainingData = passThroughOutput.OutputData,
                NumThreads   = 1
            });

            // Outer graph: load -> filter rows on Label -> key-encode Label into "Strat".
            var outerGraph   = env.CreateExperiment();
            var loader       = new Legacy.Data.TextLoader(dataPath);
            var loaderOutput = outerGraph.Add(loader);

            var rangeFilterOutput = outerGraph.Add(new Legacy.Transforms.RowRangeFilter
            {
                Data   = loaderOutput.Data,
                Column = "Label",
                Min    = 0,
                Max    = 5
            });

            var stratKeyOutput = outerGraph.Add(new Legacy.Transforms.TextToKeyConverter
            {
                Column = new[]
                {
                    new Legacy.Transforms.ValueToKeyMappingTransformerColumn()
                    {
                        Source = "Label", Name = "Strat", Sort = Legacy.Transforms.ValueToKeyMappingTransformerSortOrder.Value
                    }
                },
                Data = rangeFilterOutput.OutputData
            });

            var cv = new Legacy.Models.CrossValidator
            {
                Data                 = stratKeyOutput.OutputData,
                Nodes                = innerGraph,
                Kind                 = Legacy.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
                TransformModel       = null,
                StratificationColumn = "Strat"
            };
            cv.Inputs.Data            = passThrough.Data;
            cv.Outputs.PredictorModel = trainerOutput.PredictorModel;
            var cvOutput = outerGraph.Add(cv);

            outerGraph.Compile();
            loader.SetInput(env, outerGraph);
            outerGraph.Run();

            var warningsView = outerGraph.GetOutput(cvOutput.Warnings);
            Assert.True(warningsView.Schema.TryGetColumnIndex("WarningText", out int textCol));
            using (var rows = warningsView.GetRowCursor(c => c == textCol))
            {
                var readText = rows.GetGetter<ReadOnlyMemory<char>>(textCol);
                var text     = default(ReadOnlyMemory<char>);

                // Warning 1: test instances carrying classes absent from the training set.
                Assert.True(rows.MoveNext());
                readText(ref text);
                Assert.Contains("test instances with class values not seen in the training set.", text.ToString());

                // Warning 2: the variable-length SortedScores / SortedClasses columns.
                Assert.True(rows.MoveNext());
                readText(ref text);
                Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", text.ToString());

                // There must be no further warnings.
                Assert.False(rows.MoveNext());
            }
        }