/// <summary>
/// Cross-validates a FastTree ranking pipeline through the CrossValidator macro using
/// non-default column names ("Label1" for the label, "GroupId1" for the group id and
/// "Workclass" for the instance name), then checks the overall NDCG metrics and the
/// per-instance "Instance" column.
/// </summary>
public void TestCrossValidationMacroWithNonDefaultNames()
{
    string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
    var env = new MLContext(42);

    // Per-fold sub-graph: text label -> key ("Label1"), hash Workclass -> "GroupId1",
    // train a FastTree ranker on those columns, then combine the transform models with the
    // trained predictor into a single predictor model.
    var subGraph = env.CreateExperiment();

    var textToKey = new Legacy.Transforms.TextToKeyConverter();
    textToKey.Column = new[]
    {
        new Legacy.Transforms.ValueToKeyMappingTransformerColumn() { Name = "Label1", Source = "Label" }
    };
    var textToKeyOutput = subGraph.Add(textToKey);

    var hash = new Legacy.Transforms.HashConverter();
    hash.Column = new[]
    {
        new Legacy.Transforms.HashJoiningTransformColumn() { Name = "GroupId1", Source = "Workclass" }
    };
    hash.Data = textToKeyOutput.OutputData;
    var hashOutput = subGraph.Add(hash);

    var learnerInput = new Legacy.Trainers.FastTreeRanker
    {
        TrainingData = hashOutput.OutputData,
        NumThreads = 1,
        LabelColumn = "Label1",
        GroupIdColumn = "GroupId1"
    };
    var learnerOutput = subGraph.Add(learnerInput);

    var modelCombine = new Legacy.Transforms.ManyHeterogeneousModelCombiner
    {
        TransformModels = new ArrayVar<ITransformModel>(textToKeyOutput.Model, hashOutput.Model),
        PredictorModel = learnerOutput.PredictorModel
    };
    var modelCombineOutput = subGraph.Add(modelCombine);

    // Outer graph: load the text data and cross-validate the sub-graph above.
    var experiment = env.CreateExperiment();
    var importInput = new Legacy.Data.TextLoader(dataPath);
    importInput.Arguments.HasHeader = true;
    importInput.Arguments.Column = new TextLoaderColumn[]
    {
        new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
        new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = Legacy.Data.DataKind.Text },
        new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } }
    };
    var importOutput = experiment.Add(importInput);

    var crossValidate = new Legacy.Models.CrossValidator
    {
        Data = importOutput.Data,
        Nodes = subGraph,
        TransformModel = null,
        LabelColumn = "Label1",
        GroupColumn = "GroupId1",
        NameColumn = "Workclass",
        Kind = Legacy.Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
    };
    // Bind the sub-graph's input data variable and its final predictor model to the macro.
    crossValidate.Inputs.Data = textToKey.Data;
    crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
    var crossValidateOutput = experiment.Add(crossValidate);

    experiment.Compile();
    experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
    experiment.Run();

    // Overall metrics: rows are expected in the order Average, Standard Deviation, Fold 0, Fold 1.
    var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
    var schema = data.Schema;
    var b = schema.TryGetColumnIndex("NDCG", out int metricCol);
    Assert.True(b);
    b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
    Assert.True(b);
    using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
    {
        var getter = cursor.GetGetter<VBuffer<double>>(metricCol);
        var foldGetter = cursor.GetGetter<ReadOnlyMemory<char>>(foldCol);
        ReadOnlyMemory<char> fold = default;

        // Get the average.
        b = cursor.MoveNext();
        Assert.True(b);
        var avg = default(VBuffer<double>);
        getter(ref avg);
        foldGetter(ref fold);
        Assert.True(ReadOnlyMemoryUtils.EqualsStr("Average", fold));

        // Get the standard deviation.
        b = cursor.MoveNext();
        Assert.True(b);
        var stdev = default(VBuffer<double>);
        getter(ref stdev);
        foldGetter(ref fold);
        Assert.True(ReadOnlyMemoryUtils.EqualsStr("Standard Deviation", fold));
        var stdevValues = stdev.GetValues();
        Assert.Equal(2.462, stdevValues[0], 3);
        Assert.Equal(2.763, stdevValues[1], 3);
        Assert.Equal(3.273, stdevValues[2], 3);

        // Accumulate the two per-fold rows into a dense buffer of the same length as the average.
        var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
        sumBldr.Reset(avg.Length, true);
        var val = default(VBuffer<double>);
        for (int f = 0; f < 2; f++)
        {
            b = cursor.MoveNext();
            Assert.True(b);
            getter(ref val);
            foldGetter(ref fold);
            sumBldr.AddFeatures(0, in val);
            Assert.True(ReadOnlyMemoryUtils.EqualsStr("Fold " + f, fold));
        }
        var sum = default(VBuffer<double>);
        sumBldr.GetResult(ref sum);

        // The reported average must match the mean of the two fold rows.
        // NOTE(review): avgValues and sumValues are compared position by position, which assumes
        // the average buffer is dense — confirm.
        var avgValues = avg.GetValues();
        var sumValues = sum.GetValues();
        for (int i = 0; i < avgValues.Length; i++)
        {
            // Compare to 6 decimal places rather than exactly: the macro's average is not
            // guaranteed to be computed with the same floating-point operations as sum / 2.
            Assert.Equal(avgValues[i], sumValues[i] / 2, 6);
        }

        // Nothing follows Average, Standard Deviation and the two fold rows.
        b = cursor.MoveNext();
        Assert.False(b);
    }

    // Per-instance metrics: the first few "Instance" values must be Workclass strings seen in the data.
    data = experiment.GetOutput(crossValidateOutput.PerInstanceMetrics);
    Assert.True(data.Schema.TryGetColumnIndex("Instance", out int nameCol));
    var expectedNames = new HashSet<string>() { "Private", "?", "Federal-gov" };
    using (var cursor = data.GetRowCursor(col => col == nameCol))
    {
        var getter = cursor.GetGetter<ReadOnlyMemory<char>>(nameCol);
        int rowsChecked = 0;
        while (cursor.MoveNext())
        {
            ReadOnlyMemory<char> name = default;
            getter(ref name);
            Assert.Contains(name.ToString(), expectedNames);
            rowsChecked++;
            // Only the leading rows are inspected.
            if (cursor.Position > 4)
            {
                break;
            }
        }
        // Guard against a vacuous pass when the per-instance output is empty.
        Assert.True(rowsChecked > 0);
    }
}
/// <summary>
/// Cross-validates a FastTree ranking pipeline through the CrossValidator macro using
/// non-default label ("Label1") and group-id ("GroupId1") column names, then checks the
/// overall NDCG metrics rows (average, standard deviation and the two folds).
/// </summary>
public void TestCrossValidationMacroWithNonDefaultNames()
{
    string dataPath = GetDataPath(@"adult.tiny.with-schema.txt");
    using (var env = new TlcEnvironment(42))
    {
        // Per-fold sub-graph: text label -> key ("Label1"), hash Workclass -> "GroupId1",
        // train a FastTree ranker on those columns, then combine the transform models with the
        // trained predictor into a single predictor model.
        var subGraph = env.CreateExperiment();

        var textToKey = new ML.Transforms.TextToKeyConverter();
        textToKey.Column = new[]
        {
            new ML.Transforms.TermTransformColumn() { Name = "Label1", Source = "Label" }
        };
        var textToKeyOutput = subGraph.Add(textToKey);

        var hash = new ML.Transforms.HashConverter();
        hash.Column = new[]
        {
            new ML.Transforms.HashJoinTransformColumn() { Name = "GroupId1", Source = "Workclass" }
        };
        hash.Data = textToKeyOutput.OutputData;
        var hashOutput = subGraph.Add(hash);

        var learnerInput = new Trainers.FastTreeRanker
        {
            TrainingData = hashOutput.OutputData,
            NumThreads = 1,
            LabelColumn = "Label1",
            GroupIdColumn = "GroupId1"
        };
        var learnerOutput = subGraph.Add(learnerInput);

        var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner
        {
            TransformModels = new ArrayVar<ITransformModel>(textToKeyOutput.Model, hashOutput.Model),
            PredictorModel = learnerOutput.PredictorModel
        };
        var modelCombineOutput = subGraph.Add(modelCombine);

        // Outer graph: load the text data and cross-validate the sub-graph above.
        var experiment = env.CreateExperiment();
        var importInput = new ML.Data.TextLoader(dataPath);
        importInput.Arguments.HasHeader = true;
        importInput.Arguments.Column = new TextLoaderColumn[]
        {
            new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
            new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = DataKind.Text },
            new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } }
        };
        var importOutput = experiment.Add(importInput);

        var crossValidate = new Models.CrossValidator
        {
            Data = importOutput.Data,
            Nodes = subGraph,
            TransformModel = null,
            LabelColumn = "Label1",
            GroupColumn = "GroupId1",
            Kind = Models.MacroUtilsTrainerKinds.SignatureRankerTrainer
        };
        // Bind the sub-graph's input data variable and its final predictor model to the macro.
        crossValidate.Inputs.Data = textToKey.Data;
        crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel;
        var crossValidateOutput = experiment.Add(crossValidate);

        experiment.Compile();
        experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
        experiment.Run();

        // Overall metrics: rows are expected in the order Average, Standard Deviation, Fold 0, Fold 1.
        var data = experiment.GetOutput(crossValidateOutput.OverallMetrics);
        var schema = data.Schema;
        var b = schema.TryGetColumnIndex("NDCG", out int metricCol);
        Assert.True(b);
        b = schema.TryGetColumnIndex("Fold Index", out int foldCol);
        Assert.True(b);
        using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol))
        {
            var getter = cursor.GetGetter<VBuffer<double>>(metricCol);
            var foldGetter = cursor.GetGetter<DvText>(foldCol);
            DvText fold = default;

            // Get the average.
            b = cursor.MoveNext();
            Assert.True(b);
            var avg = default(VBuffer<double>);
            getter(ref avg);
            foldGetter(ref fold);
            Assert.True(fold.EqualsStr("Average"));

            // Get the standard deviation.
            b = cursor.MoveNext();
            Assert.True(b);
            var stdev = default(VBuffer<double>);
            getter(ref stdev);
            foldGetter(ref fold);
            Assert.True(fold.EqualsStr("Standard Deviation"));
            Assert.Equal(5.247, stdev.Values[0], 3);
            Assert.Equal(4.703, stdev.Values[1], 3);
            Assert.Equal(3.844, stdev.Values[2], 3);

            // Accumulate the two per-fold rows into a dense buffer of the same length as the average.
            var sumBldr = new BufferBuilder<double>(R8Adder.Instance);
            sumBldr.Reset(avg.Length, true);
            var val = default(VBuffer<double>);
            for (int f = 0; f < 2; f++)
            {
                b = cursor.MoveNext();
                Assert.True(b);
                getter(ref val);
                foldGetter(ref fold);
                sumBldr.AddFeatures(0, ref val);
                Assert.True(fold.EqualsStr("Fold " + f));
            }
            var sum = default(VBuffer<double>);
            sumBldr.GetResult(ref sum);

            // The reported average must match the mean of the two fold rows.
            // NOTE(review): avg.Values is indexed by slot up to avg.Length, which assumes the
            // average buffer is dense — confirm.
            for (int i = 0; i < avg.Length; i++)
            {
                // Compare to 6 decimal places rather than exactly: the macro's average is not
                // guaranteed to be computed with the same floating-point operations as sum / 2.
                Assert.Equal(avg.Values[i], sum.Values[i] / 2, 6);
            }

            // Nothing follows Average, Standard Deviation and the two fold rows.
            b = cursor.MoveNext();
            Assert.False(b);
        }
    }
}