/// <summary>
/// Entry point that applies a <see cref="GenerateNumberTransform"/> to <c>input.Data</c>.
/// </summary>
/// <param name="env">Environment from which the entry-point host (named "GenerateNumber") is created.</param>
/// <param name="input">Transform arguments; <c>input.Data</c> is the data view being transformed.</param>
/// <returns>
/// A <see cref="CommonOutputs.TransformOutput"/> whose <c>OutputData</c> is the transformed view and whose
/// <c>Model</c> is a <see cref="TransformModel"/> built from that transform over <c>input.Data</c>.
/// </returns>
public static CommonOutputs.TransformOutput Generate(IHostEnvironment env, GenerateNumberTransform.Arguments input)
{
    var host = EntryPointUtils.CheckArgsAndCreateHost(env, "GenerateNumber", input);
    var transform = new GenerateNumberTransform(host, input, input.Data);

    return new CommonOutputs.TransformOutput
    {
        Model = new TransformModel(host, transform, input.Data),
        OutputData = transform
    };
}
/// <summary>
/// Reshapes the evaluator's per-instance data into the column layout Maml reports, then hands it to
/// <c>GetPerInstanceMetricsCore</c>.
/// </summary>
/// <remarks>
/// The kept columns are, in order:
/// <list type="bullet">
/// <item><description>the fold index column, when present (cross-validation output);</description></item>
/// <item><description>an "Instance" column, copied from the name column or generated with a counter when there is no name column;</description></item>
/// <item><description>the weight column, when present;</description></item>
/// <item><description>the columns returned by <c>GetPerInstanceColumnsToSave</c>.</description></item>
/// </list>
/// </remarks>
/// <param name="perInst">Per-instance data together with its role mapping.</param>
/// <returns>The result of <c>GetPerInstanceMetricsCore</c> applied to the column-selected view.</returns>
private IDataView WrapPerInstance(RoleMappedData perInst)
{
    const string instanceColumn = "Instance";

    IDataView view = perInst.Data;
    var copyPairs = new List<(string Source, string Name)>();
    var keep = new List<string>();

    // Cross-validation output may carry a fold index column; keep it when it exists.
    if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out _))
    {
        keep.Add(MetricKinds.ColumnNames.FoldIndex);
    }

    // Maml always emits an "Instance" column: copy it from the name column when there is one,
    // otherwise synthesize it with a counter-based GenerateNumberTransform.
    if (perInst.Schema.Name != null)
    {
        copyPairs.Add((perInst.Schema.Name.Name, instanceColumn));
    }
    else
    {
        var generateArgs = new GenerateNumberTransform.Arguments
        {
            Column = new[] { new GenerateNumberTransform.Column { Name = instanceColumn } },
            UseCounter = true
        };
        view = new GenerateNumberTransform(Host, generateArgs, view);
    }
    keep.Add(instanceColumn);

    // The weight column is reported only when the schema defines one.
    if (perInst.Schema.Weight != null)
    {
        keep.Add(perInst.Schema.Weight.Name);
    }

    // Evaluator-specific per-instance columns.
    keep.AddRange(GetPerInstanceColumnsToSave(perInst.Schema));

    view = new ColumnsCopyingTransformer(Host, copyPairs.ToArray()).Transform(view);
    view = ColumnSelectingTransformer.CreateKeep(Host, view, keep.ToArray());
    return GetPerInstanceMetricsCore(view, perInst.Schema);
}
/// <summary>
/// Determines the column used to stratify the data split. When needed, it adds a generated or hashed
/// column to <paramref name="output"/>.
/// </summary>
/// <param name="ch">Channel used for informational messages and user-argument errors.</param>
/// <param name="input">
/// Data view to split. The stratification column and/or group column, if they exist at all, must already be
/// present in its schema.
/// </param>
/// <param name="output">
/// Set to <paramref name="input"/> when an existing column can be used directly. Otherwise it is set to
/// <paramref name="input"/> wrapped in a <see cref="GenerateNumberTransform"/> (no usable column) or a
/// <see cref="HashTransform"/> (the chosen column's type is not valid for a range filter). The wrapper adds
/// the returned column.
/// </param>
/// <returns>The name of the stratification column as it exists in <paramref name="output"/>.</returns>
/// <remarks>
/// Throws the exception produced by <c>ch.ExceptUserArg</c> when <c>Args.StratificationColumn</c> names a
/// column that is not in <paramref name="input"/>.
/// </remarks>
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
{
    var schema = input.Schema;
    output = input;

    // Prefer the user-specified stratification column. Otherwise fall back to the group column, but only when
    // it is a key type with known cardinality (KeyCount > 0). Group columns of any other type are NOT used;
    // the code below then generates a fresh column instead.
    string stratificationColumn = null;
    if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
    {
        stratificationColumn = Args.StratificationColumn;
    }
    else
    {
        string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
        if (group != null && schema.TryGetColumnIndex(group, out int groupIndex) && schema.GetColumnType(groupIndex).KeyCount > 0)
        {
            stratificationColumn = group;
        }
    }

    if (string.IsNullOrEmpty(stratificationColumn))
    {
        // No usable column: generate one under a name that does not collide with any existing column
        // ("StratificationColumn", then "StratificationColumn_001", "StratificationColumn_002", ...).
        stratificationColumn = "StratificationColumn";
        int inc = 0;
        while (schema.TryGetColumnIndex(stratificationColumn, out _))
        {
            stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
        }

        var keyGenArgs = new GenerateNumberTransform.Arguments();
        keyGenArgs.Column = new[] { new GenerateNumberTransform.Column { Name = stratificationColumn } };
        output = new GenerateNumberTransform(Host, keyGenArgs, input);
        return stratificationColumn;
    }

    if (!schema.TryGetColumnIndex(stratificationColumn, out int col))
    {
        throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
    }

    var type = schema.GetColumnType(col);
    if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
    {
        // The column cannot be range-filtered as is, so hash it into a new column. The source column exists,
        // so the loop always runs at least once and yields "<name>_001" or the next free suffix.
        ch.Info("Hashing the stratification column");
        var origStratCol = stratificationColumn;
        int inc = 0;
        while (schema.TryGetColumnIndex(stratificationColumn, out _))
        {
            stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
        }

        var hashargs = new HashTransform.Arguments();
        hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
        hashargs.HashBits = 30;
        output = new HashTransform(Host, hashargs, input);
    }

    return stratificationColumn;
}