/// <summary>
/// Converts the key columns listed in <paramref name="ktv"/> into indicator vectors by chaining
/// three transforms on top of <paramref name="viewTrain"/>: KeyToValue, Term, and KeyToVector.
/// </summary>
/// <param name="ktv">Columns to convert. May be null or empty, in which case no transform is applied.</param>
/// <param name="viewTrain">The data view the transforms are stacked on.</param>
/// <param name="host">Host passed to each transform constructor.</param>
/// <returns>
/// The view produced by the final KeyToVector transform, or <paramref name="viewTrain"/> itself
/// when there are no columns to convert.
/// </returns>
private static IDataView ApplyKeyToVec(List<KeyToVectorTransform.Column> ktv, IDataView viewTrain, IHost host)
{
    Contracts.AssertValueOrNull(ktv);
    Contracts.AssertValue(viewTrain);
    Contracts.AssertValue(host);

    // No columns requested: hand the input view back untouched.
    if (Utils.Size(ktv) <= 0)
    {
        return viewTrain;
    }

    // A plain KeyToVector is not used on its own because a very common situation is that the
    // training and testing sets carry slightly different key values. To cope with that, each
    // column is first mapped back to its values, then re-keyed with an explicit term list,
    // and only then expanded into a vector.

    // Stage 1: key -> value, writing each result to the column's destination name.
    var toValueArgs = new KeyToValueTransform.Arguments();
    toValueArgs.Column = ktv
        .Select(col => new KeyToValueTransform.Column() { Name = col.Name, Source = col.Source })
        .ToArray();
    IDataView valueView = new KeyToValueTransform(host, toValueArgs, viewTrain);

    // Stage 2: value -> key, in place on the destination column. The term list for each column
    // comes from GetTerms, which is handed the stage-1 view together with the column's original
    // source name. ToArray() forces those GetTerms calls to happen right here.
    var termArgs = new Data.TermTransform.Arguments();
    termArgs.Column = ktv
        .Select(col => new Data.TermTransform.Column()
        {
            Name = col.Name,
            Source = col.Name,
            Terms = GetTerms(valueView, col.Source)
        })
        .ToArray();
    termArgs.TextKeyValues = true;
    IDataView termView = new Data.TermTransform(host, termArgs, valueView);

    // Stage 3: key -> vector, in place on the destination column, with bagging turned off.
    var toVectorArgs = new KeyToVectorTransform.Arguments();
    toVectorArgs.Column = ktv
        .Select(col => new KeyToVectorTransform.Column() { Name = col.Name, Source = col.Name })
        .ToArray();
    toVectorArgs.Bag = false;
    IDataView vectorView = new KeyToVectorTransform(host, toVectorArgs, termView);

    return vectorView;
}
/// <summary>
/// Produces a transform that leaves the label column of <paramref name="input"/> as-is when it is
/// already of key or boolean type, and otherwise maps it to a key type with a Term transform.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Entry-point arguments carrying the data, the label column name and the TextKeyValues flag.</param>
/// <returns>A <c>TransformOutput</c> holding the resulting data view and the matching transform model.</returns>
/// <remarks>Throws the exception built by <c>host.Except</c> when the label column is not present in the schema.</remarks>
public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvironment env, ClassificationLabelInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("PrepareClassificationLabel");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    // Read the pieces of the input used repeatedly below (only after the input has been validated).
    var data = input.Data;
    var labelName = input.LabelColumn;
    var schema = data.Schema;

    int labelIndex;
    bool labelFound = schema.TryGetColumnIndex(labelName, out labelIndex);
    if (!labelFound)
    {
        throw host.Except($"Column '{labelName}' not found.");
    }

    var labelType = schema.GetColumnType(labelIndex);
    if (!labelType.IsKey && !labelType.IsBool)
    {
        // The label is neither a key nor a boolean: convert it to a key in place
        // (same source and destination name), ordering the terms by value.
        var termArgs = new Data.TermTransform.Arguments();
        var labelTerm = new Data.TermTransform.Column();
        labelTerm.Name = labelName;
        labelTerm.Source = labelName;
        labelTerm.TextKeyValues = input.TextKeyValues;
        labelTerm.Sort = Data.TermTransform.SortOrder.Value;
        termArgs.Column = new[] { labelTerm };

        var termTransform = new Data.TermTransform(host, termArgs, data);
        return new CommonOutputs.TransformOutput
        {
            Model = new TransformModel(env, termTransform, data),
            OutputData = termTransform
        };
    }

    // The label is already a key or a boolean: no conversion is applied.
    var passThrough = NopTransform.CreateIfNeeded(env, data);
    return new CommonOutputs.TransformOutput
    {
        Model = new TransformModel(env, passThrough, data),
        OutputData = passThrough
    };
}