/// <summary>
/// Schema propagation: validates that every input column referenced by each expression
/// column exists in <paramref name="inputSchema"/>, binds each expression to determine its
/// result type, and returns the schema shape with the output columns added/overwritten.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var columnDictionary = inputSchema.ToDictionary(x => x.Name);
    for (int i = 0; i < _columns.Length; i++)
    {
        // Every source column of this expression must be present in the input schema.
        for (int j = 0; j < _columns[i].InputColumnNames.Length; j++)
        {
            if (!inputSchema.TryFindColumn(_columns[i].InputColumnNames[j], out var inputCol))
            {
                throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[i].InputColumnNames[j]);
            }
        }
        // Make sure there is at most one vector valued source column.
        // ivec is the index of that vector column, or -1 if all sources are scalar.
        var inputTypes = new DataViewType[_columns[i].InputColumnNames.Length];
        var ivec = FindVectorInputColumn(_host, _columns[i].InputColumnNames, inputSchema, inputTypes);
        var node = ParseAndBindLambda(_host, _columns[i].Expression, ivec, inputTypes, out var perm);

        // Expressions always produce a primitive (scalar item) result type.
        var typeRes = node.ResultType;
        _host.Assert(typeRes is PrimitiveDataViewType);

        // If one of the input columns is a vector column, we pass through the slot names metadata
        // and the output inherits that column's vector kind; otherwise the output is scalar.
        SchemaShape.Column.VectorKind outputVectorKind;
        var metadata = new List<SchemaShape.Column>();
        if (ivec == -1)
        {
            outputVectorKind = SchemaShape.Column.VectorKind.Scalar;
        }
        else
        {
            inputSchema.TryFindColumn(_columns[i].InputColumnNames[ivec], out var vectorCol);
            outputVectorKind = vectorCol.Kind;
            if (vectorCol.HasSlotNames())
            {
                var b = vectorCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                _host.Assert(b);
                metadata.Add(slotNames);
            }
        }
        var outputSchemaShapeColumn = new SchemaShape.Column(_columns[i].Name, outputVectorKind, typeRes, false, new SchemaShape(metadata));
        // Output replaces any same-named input column in the result schema.
        columnDictionary[_columns[i].Name] = outputSchemaShapeColumn;
    }
    return (new SchemaShape(columnDictionary.Values));
}
/// <summary>
/// Verifies the label column is a scalar of an accepted type (bool, float, double, or key),
/// throwing a schema-mismatch exception otherwise.
/// </summary>
protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
    Contracts.Assert(labelCol.IsValid);

    // Both failure modes raise the same schema-mismatch error.
    Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString());

    if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
        error();

    // Acceptable item types: key, R4, R8, or boolean.
    bool typeOk = labelCol.IsKey
        || labelCol.ItemType == NumberType.R4
        || labelCol.ItemType == NumberType.R8
        || labelCol.ItemType is BoolType;
    if (!typeOk)
        error();
}
/// <summary>
/// Verifies the label column is a scalar whose item type is float, double, bool, or a key type;
/// raises a schema-mismatch exception otherwise.
/// </summary>
private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
{
    Contracts.Assert(labelCol.IsValid);

    // Shared failure path for both the shape check and the item-type check.
    Action error = () => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

    if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
        error();

    bool itemTypeOk = labelCol.IsKey
        || labelCol.ItemType == NumberDataViewType.Single
        || labelCol.ItemType == NumberDataViewType.Double
        || labelCol.ItemType is BooleanDataViewType;
    if (!itemTypeOk)
        error();
}
/// <summary>
/// Initializes boosting hyperparameters; after the base constructor (which applies
/// <paramref name="advancedSettings"/>), the directly supplied learning rate takes precedence.
/// </summary>
protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
    string weightColumn, string groupIdColumn, int numLeaves, int numTrees, int minDocumentsInLeafs,
    double learningRate, Action<TArgs> advancedSettings)
    : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
{
    // Nothing to do when the advanced settings already agree with the supplied value.
    if (Args.LearningRates == learningRate)
        return;

    using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
        Args.LearningRates = learningRate;
}
/// <summary>
/// Builds a concrete <see cref="ColumnType"/> mirroring the shape description:
/// the item type, optionally wrapped as a key and/or a (variable-size) vector.
/// </summary>
private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
{
    ColumnType result = inputCol.ItemType;
    if (inputCol.IsKey)
        result = new KeyType(result.AsPrimitive.RawKind, 0, AllKeySizes);

    switch (inputCol.Kind)
    {
        case SchemaShape.Column.VectorKind.VariableVector:
            // Size 0 denotes a variable-length vector.
            result = new VectorType(result.AsPrimitive, 0);
            break;
        case SchemaShape.Column.VectorKind.Vector:
            result = new VectorType(result.AsPrimitive, AllVectorSizes);
            break;
    }
    return result;
}
/// <summary>
/// Materializes a <see cref="ColumnType"/> from a schema-shape column: item type,
/// key-wrapped if the column is a key, vector-wrapped per the column's vector kind.
/// </summary>
private static ColumnType MakeColumnType(SchemaShape.Column column)
{
    ColumnType shaped = column.ItemType;
    if (column.IsKey)
        shaped = new KeyType(((PrimitiveType)shaped).RawKind, 0, AllKeySizes);

    // Early-return per vector kind; scalars fall through to the (possibly key-wrapped) item type.
    if (column.Kind == SchemaShape.Column.VectorKind.VariableVector)
        return new VectorType((PrimitiveType)shaped, 0);
    if (column.Kind == SchemaShape.Column.VectorKind.Vector)
        return new VectorType((PrimitiveType)shaped, AllVectorSizes);
    return shaped;
}
/// <summary>
/// Validates the user-supplied L-BFGS arguments and copies them into trainer fields.
/// </summary>
internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
    : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
        labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
    Host.CheckValue(args, nameof(args));
    Args = args;

    // Argument validation: thread count may be null (auto), otherwise positive / non-negative per check.
    Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null, nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
    Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
    Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
    Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
    Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
    Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
    Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
    // NOTE(review): this overlaps with the UseThreads check above; both are kept as-is.
    Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

    // Copy validated settings into the trainer's own fields.
    L2Weight = Args.L2Weight;
    L1Weight = Args.L1Weight;
    OptTol = Args.OptTol;
    MemorySize = Args.MemorySize;
    MaxIterations = Args.MaxIterations;
    SgdInitializationTolerance = Args.SgdInitializationTolerance;
    Quiet = Args.Quiet;
    InitWtsDiameter = Args.InitWtsDiameter;
    UseThreads = Args.UseThreads;
    NumThreads = Args.NumThreads;
    DenseOptimizer = Args.DenseOptimizer;
    EnforceNonNegativity = Args.EnforceNonNegativity;

    // Training statistics are incompatible with the non-negativity constraint; warn if the
    // caller asked for both.
    if (EnforceNonNegativity && ShowTrainingStats)
    {
        ShowTrainingStats = false;
        using (var ch = Host.Start("Initialization"))
        {
            ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
            ch.Done();
        }
    }
    // NOTE(review): this unconditional reset makes the conditional disable above redundant and
    // forces stats off for every configuration — confirm whether it is a temporary stopgap.
    ShowTrainingStats = false;
    _srcPredictor = default;
}
/// <summary>
/// Initializes the <see cref="MetaMulticlassClassificationTrainer{TTransformer, TModel}"/> from the <see cref="OptionsBase"/> class.
/// </summary>
/// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
/// <param name="options">The legacy arguments <see cref="OptionsBase"/> class.</param>
/// <param name="name">The component name.</param>
/// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
/// <param name="singleEstimator">The binary estimator.</param>
/// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/>.</param>
internal MetaMulticlassClassificationTrainer(IHostEnvironment env, OptionsBase options, string name, string labelColumn = null,
    TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
{
    Host = Contracts.CheckRef(env, nameof(env)).Register(name);
    Host.CheckValue(options, nameof(options));
    Args = options;

    // The meta trainer's label is a scalar key (UInt32) when a name is supplied.
    if (labelColumn != null)
        LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

    Trainer = singleEstimator ?? CreateTrainer();

    // A calibrator specified via options overrides the parameter/default.
    Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
    if (options.Calibrator != null)
        Calibrator = options.Calibrator.CreateComponent(Host);

    // Regarding caching, no matter what the internal predictor, we're performing many passes
    // simply by virtue of this being a meta-trainer, so we will still cache.
    Info = new TrainerInfo(normalization: Trainer.Info.NeedNormalization);
}
/// <summary>
/// Schema propagation for transformers. Returns the output schema of the data, if
/// the input schema is like the one provided.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));

    // Verifies a required column exists in the input schema and matches the expected shape.
    void EnsurePresentAndCompatible(SchemaShape.Column cachedColumn, string columnRole)
    {
        if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
            throw _host.ExceptSchemaMismatch(nameof(col), columnRole, cachedColumn.Name);
        if (!cachedColumn.IsCompatibleWith(col))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name, cachedColumn.GetTypeString(), col.GetTypeString());
    }

    // Label must be a float scalar.
    EnsurePresentAndCompatible(new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false), "label");

    // Matrix row/column index columns must be scalar key-typed UInt32 columns. Note that
    // column of IDataView and column of matrix are two different things.
    EnsurePresentAndCompatible(new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true), "matrixColumnIndex");
    EnsurePresentAndCompatible(new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true), "matrixRowIndex");

    // Input columns pass through; estimator outputs are added, overwriting same-named inputs.
    var outColumns = inputSchema.ToDictionary(x => x.Name);
    foreach (var col in GetOutputColumnsCore(inputSchema))
        outColumns[col.Name] = col;

    return new SchemaShape(outColumns.Values);
}
/// <summary>
/// Initializes an FFM trainer where each entry of <paramref name="featureColumns"/> is one
/// field (a float vector); label is a boolean scalar, weight an optional float scalar.
/// </summary>
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string[] featureColumns,
    string labelColumn = DefaultColumnNames.Label, string weights = null)
    : base(env, LoadName)
{
    var args = new Options();
    Initialize(env, args);
    Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

    FeatureColumns = new SchemaShape.Column[featureColumns.Length];
    for (int idx = 0; idx < featureColumns.Length; idx++)
        FeatureColumns[idx] = new SchemaShape.Column(featureColumns[idx], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

    LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
    WeightColumn = weights != null
        ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
        : default;
}
/// <summary>
/// Legacy constructor initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/> through the legacy
/// <see cref="Arguments"/> class.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="args">An instance of the legacy <see cref="Arguments"/> to apply advanced parameters to the algorithm.</param>
public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args)
    : base(env, LoadName)
{
    Initialize(env, args);
    Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

    // There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
    // BUGFIX: the array length was previously computed from args.ExtraFeatureColumns.Length without
    // a null check (the loop guarded against null, but too late), throwing NullReferenceException
    // when no extra feature columns were supplied.
    int extraCount = args.ExtraFeatureColumns != null ? args.ExtraFeatureColumns.Length : 0;
    FeatureColumns = new SchemaShape.Column[1 + extraCount];

    // Treat the default feature column as the 1st field.
    FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

    // Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
    for (int i = 0; i < extraCount; i++)
        FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

    LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
    WeightColumn = args.WeightColumn.IsExplicit
        ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
        : default;
}
/// <summary>
/// Sets up the label, score and weight column shapes used by the calibrator.
/// The label may be omitted only when the supplied calibrator trainer needs no training.
/// </summary>
private protected CalibratorEstimatorBase(IHostEnvironment env, ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn)
{
    Host = env;
    _calibratorTrainer = calibratorTrainer;

    if (!string.IsNullOrWhiteSpace(labelColumn))
    {
        LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
    }
    else
    {
        // No label is acceptable only for pre-trained calibrators.
        env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified.");
    }
    // Do we fathom this being named anything else (renaming column)? Complete metadata?
    ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn);

    if (weightColumn != null)
    {
        WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
    }
}
/// <summary>
/// Schema propagation: every configured input must exist and be a primitive or a vector of
/// primitives; each output becomes a key (U4) column of the same vector kind as its input.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.Columns.ToDictionary(x => x.Name);
    foreach (var colInfo in _columns)
    {
        if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
        }
        // Reject columns without a usable raw kind or that are neither vector nor primitive.
        // NOTE(review): col.ItemType.ItemType (doubled) relies on ItemType of a primitive being
        // itself — confirm against the ColumnType contract.
        if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
        }
        SchemaShape metadata;

        // In the event that we are transforming something that is of type key, we will get their type of key value
        // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
        if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
        {
            kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, col.ItemType, col.IsKey);
        }
        Contracts.AssertValue(kv);

        // Slot names pass through when the input has them.
        if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
        {
            metadata = new SchemaShape(new[] { slotMeta, kv });
        }
        else
        {
            metadata = new SchemaShape(new[] { kv });
        }
        result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
    }
    return (new SchemaShape(result.Values));
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline. Each configured input must be a
/// standard scalar (or vector thereof); its output becomes a key-typed UInt32 column of the same
/// vector kind, with key-values and (when present) slot-names annotations.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    _host.CheckValue(inputSchema, nameof(inputSchema));
    var result = inputSchema.ToDictionary(x => x.Name);
    foreach (var colInfo in _columns)
    {
        if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
        }
        if (!col.ItemType.IsStandardScalar())
        {
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
        }
        SchemaShape metadata;

        // In the event that we are transforming something that is of type key, we will get their type of key value
        // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
        if (!col.IsKey || !col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
        {
            // TextKeyValues forces the key-values annotation to be text regardless of the item type.
            kv = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                colInfo.TextKeyValues ? TextDataViewType.Instance : col.ItemType, col.IsKey);
        }
        Contracts.Assert(kv.IsValid);

        // Slot names pass through when the input has them.
        if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
        {
            metadata = new SchemaShape(new[] { slotMeta, kv });
        }
        else
        {
            metadata = new SchemaShape(new[] { kv });
        }
        result[colInfo.OutputColumnName] = new SchemaShape.Column(colInfo.OutputColumnName, col.Kind, NumberDataViewType.UInt32, true, metadata);
    }
    return (new SchemaShape(result.Values));
}
/// <summary>
/// Initializes an FFM trainer; each of <paramref name="featureColumnNames"/> is treated as one
/// field (a float vector). Label is a boolean scalar; the example weight column is optional.
/// </summary>
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string[] featureColumnNames,
    string labelColumnName = DefaultColumnNames.Label, string exampleWeightColumnName = null)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoadName);

    var args = new Options();
    Initialize(env, args);

    FeatureColumns = new SchemaShape.Column[featureColumnNames.Length];
    for (int idx = 0; idx < featureColumnNames.Length; idx++)
        FeatureColumns[idx] = new SchemaShape.Column(featureColumnNames[idx], SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

    LabelColumn = new SchemaShape.Column(labelColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
    WeightColumn = exampleWeightColumnName != null
        ? new SchemaShape.Column(exampleWeightColumnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false)
        : default;
}
/// <summary>
/// Initializes the LightGBM trainer with directly supplied hyperparameters; the
/// <paramref name="advancedSettings"/> delegate may override them before the column
/// names are pinned to the supplied values.
/// </summary>
private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn,
    string weightColumn, string groupIdColumn, int? numLeaves, int? minDataPerLeaf, double? learningRate, int numBoostRound,
    Action<LightGbmArguments> advancedSettings)
    : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label,
        TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
    // Seed the argument object with the direct hyperparameters.
    Args = new LightGbmArguments
    {
        NumLeaves = numLeaves,
        MinDataPerLeaf = minDataPerLeaf,
        LearningRate = learningRate,
        NumBoostRound = numBoostRound
    };

    //apply the advanced args, if the user supplied any
    advancedSettings?.Invoke(Args);

    // Column names always come from the direct parameters, not the delegate.
    Args.LabelColumn = label.Name;
    Args.FeatureColumn = featureColumn;
    if (weightColumn != null)
        Args.WeightColumn = Optional<string>.Explicit(weightColumn);
    if (groupIdColumn != null)
        Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);

    InitParallelTraining();
}
/// <summary>
/// Initializes an FFM trainer from <see cref="Options"/>. The default feature column is field 1;
/// any entries in <see cref="Options.ExtraFeatureColumns"/> become the remaining fields.
/// </summary>
internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(LoadName);
    Initialize(env, options);

    // There can be multiple feature columns in FFM, jointly specified by options.FeatureColumnName
    // and options.ExtraFeatureColumns (which may be null, meaning none).
    int extraCount = options.ExtraFeatureColumns == null ? 0 : options.ExtraFeatureColumns.Length;
    FeatureColumns = new SchemaShape.Column[1 + extraCount];

    // Treat the default feature column as the 1st field.
    FeatureColumns[0] = new SchemaShape.Column(options.FeatureColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

    // Add 2nd, 3rd, and other fields from the FFM-specific ExtraFeatureColumns argument.
    for (int idx = 0; idx < extraCount; idx++)
        FeatureColumns[idx + 1] = new SchemaShape.Column(options.ExtraFeatureColumns[idx], SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

    LabelColumn = new SchemaShape.Column(options.LabelColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
    WeightColumn = options.ExampleWeightColumnName != null
        ? new SchemaShape.Column(options.ExampleWeightColumnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false)
        : default;
}
/// <summary>
/// Builds a <typeparamref name="TOptions"/> instance from the individual hyperparameters.
/// </summary>
private static TOptions ArgsInit(string featureColumn, SchemaShape.Column labelColumn,
    string weightColumn, float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity)
{
    // FIX: ExampleWeightColumnName was previously assigned twice with the same value
    // (in the initializer and again afterwards); the redundant second assignment is removed.
    return new TOptions
    {
        FeatureColumnName = featureColumn,
        LabelColumnName = labelColumn.Name,
        ExampleWeightColumnName = weightColumn,
        L1Regularization = l1Weight,
        L2Regularization = l2Weight,
        OptmizationTolerance = optimizationTolerance,
        IterationsToRemember = memorySize,
        EnforceNonNegativity = enforceNoNegativity
    };
}
/// <summary>
/// Convenience constructor: packages the individual hyperparameters into a <typeparamref name="TArgs"/>
/// instance and delegates to the argument-object constructor. A null weight column falls back to
/// the implicit default weight column name.
/// </summary>
internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
    string weightColumn, float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity)
    : this(env,
        new TArgs
        {
            FeatureColumn = featureColumn,
            LabelColumn = labelColumn.Name,
            WeightColumn = weightColumn != null
                ? Optional<string>.Explicit(weightColumn)
                : Optional<string>.Implicit(DefaultColumnNames.Weight),
            L1Weight = l1Weight,
            L2Weight = l2Weight,
            OptTol = optimizationTolerance,
            MemorySize = memorySize,
            EnforceNonNegativity = enforceNoNegativity
        },
        labelColumn)
{
}
/// <summary>
/// Convenience constructor: wraps the individual hyperparameters in a <typeparamref name="TOptions"/>
/// instance and forwards to the options-based constructor.
/// </summary>
internal LbfgsTrainerBase(IHostEnvironment env, string featureColumn, SchemaShape.Column labelColumn,
    string weightColumn, float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity)
    : this(env,
        new TOptions
        {
            FeatureColumnName = featureColumn,
            LabelColumnName = labelColumn.Name,
            ExampleWeightColumnName = weightColumn,
            L1Regularization = l1Weight,
            L2Regularization = l2Weight,
            OptmizationTolerance = optimizationTolerance,
            IterationsToRemember = memorySize,
            EnforceNonNegativity = enforceNoNegativity
        },
        labelColumn)
{
}
/// <summary>
/// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param>
/// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
public MatrixFactorizationTrainer(IHostEnvironment env, string labelColumn, string matrixColumnIndexColumnName, string matrixRowIndexColumnName,
    TrainerEstimatorContext context = null, Action<Arguments> advancedSettings = null)
    : base(env, LoadNameValue)
{
    var args = new Arguments();
    advancedSettings?.Invoke(args);

    // Copy the (possibly delegate-adjusted) hyperparameters into trainer fields.
    _lambda = args.Lambda;
    _k = args.K;
    _iter = args.NumIterations;
    _eta = args.Eta;
    _threads = args.NumThreads ?? Environment.ProcessorCount;
    _quiet = args.Quiet;
    _doNmf = args.NonNegative;

    // Matrix factorization operates on raw indices, so no normalization or caching.
    Info = new TrainerInfo(normalization: false, caching: false);
    Context = context;

    LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
    MatrixColumnIndexColumn = new SchemaShape.Column(matrixColumnIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
    MatrixRowIndexColumn = new SchemaShape.Column(matrixRowIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
}
/// <summary>
/// Initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
/// </summary>
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumns">The name of column hosting the features.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
/// <param name="weightColumn">The name of the weight column.</param>
/// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string[] featureColumns,
    string weightColumn = null, TrainerEstimatorContext context = null, Action<Arguments> advancedSettings = null)
    : base(env, LoadName)
{
    var args = new Arguments();
    advancedSettings?.Invoke(args);
    Initialize(env, args);
    Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
    Context = context;

    // Each named column is one FFM field, all float vectors.
    FeatureColumns = new SchemaShape.Column[featureColumns.Length];
    for (int idx = 0; idx < featureColumns.Length; idx++)
        FeatureColumns[idx] = new SchemaShape.Column(featureColumns[idx], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

    LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
    WeightColumn = weightColumn != null
        ? new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false)
        : null;
}
/// <summary>
/// Builds a <typeparamref name="TArgs"/> instance from the individual hyperparameters.
/// </summary>
private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn,
    string weightColumn, float l1Weight, float l2Weight, float optimizationTolerance, int memorySize, bool enforceNoNegativity)
{
    // FIX: WeightColumn was previously initialized to
    // `weightColumn ?? Optional<string>.Explicit(weightColumn)` and then immediately overwritten
    // with the plain `weightColumn` value, making the Explicit(...) wrapper dead code (and
    // nonsensical, since Explicit was only reached when weightColumn was null). The net effect
    // was a single implicit conversion from `weightColumn`, which is now done directly.
    return new TArgs
    {
        FeatureColumn = featureColumn,
        LabelColumn = labelColumn.Name,
        WeightColumn = weightColumn,
        L1Weight = l1Weight,
        L2Weight = l2Weight,
        OptTol = optimizationTolerance,
        MemorySize = memorySize,
        EnforceNonNegativity = enforceNoNegativity
    };
}
/// <summary>
/// Initializes the LightGBM trainer with default arguments; the optional
/// <paramref name="advancedSettings"/> delegate may adjust them before the
/// column names are pinned to the supplied parameters.
/// </summary>
private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn,
    string weightColumn = null, string groupIdColumn = null, Action<LightGbmArguments> advancedSettings = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label,
        TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
    Args = new LightGbmArguments();

    //apply the advanced args, if the user supplied any
    advancedSettings?.Invoke(Args);

    // Directly supplied column names take precedence over anything the delegate set.
    Args.LabelColumn = label.Name;
    Args.FeatureColumn = featureColumn;
    if (weightColumn != null)
        Args.WeightColumn = weightColumn;
    if (groupIdColumn != null)
        Args.GroupIdColumn = groupIdColumn;

    InitParallelTraining();
}
/// <summary>
/// Validates the source columns for a concatenation and computes the resulting output column
/// shape: item types must all match, keys are rejected, and the output's vector kind and
/// metadata (normalization, categorical ranges, slot names) are derived from the inputs.
/// </summary>
/// <param name="inputSchema">The schema the sources are looked up in.</param>
/// <param name="name">The name of the output column.</param>
/// <param name="sources">The names of the columns to concatenate; must be non-empty.</param>
private SchemaShape.Column CheckInputsAndMakeColumn(
    SchemaShape inputSchema, string name, string[] sources)
{
    _host.AssertNonEmpty(sources);
    // (Removed a `cols` array that was allocated but never used.)

    // If any input is a var vector, so is the output.
    bool varVector = false;
    // If any input is not normalized, the output is not normalized.
    bool isNormalized = true;
    // If any input has categorical indices, so will the output.
    bool hasCategoricals = false;
    // If any is scalar or had slot names, then the output will have slot names.
    bool hasSlotNames = false;

    // We will get the item type from the first column.
    ColumnType itemType = null;
    for (int i = 0; i < sources.Length; ++i)
    {
        if (!inputSchema.TryFindColumn(sources[i], out var col))
            throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sources[i]);
        if (i == 0)
            itemType = col.ItemType;
        // For the sake of an estimator I am going to have a hard policy of no keys.
        // Appending keys makes no real sense anyway.
        if (col.IsKey)
        {
            throw _host.Except($"Column '{sources[i]}' is key." +
                $"Concatenation of keys is unsupported.");
        }
        if (!col.ItemType.Equals(itemType))
        {
            throw _host.Except($"Column '{sources[i]}' has values of {col.ItemType}" +
                $"which is not the same as earlier observed type of {itemType}.");
        }
        varVector |= col.Kind == SchemaShape.Column.VectorKind.VariableVector;
        isNormalized &= col.IsNormalized();
        hasCategoricals |= HasCategoricals(col);
        hasSlotNames |= col.Kind == SchemaShape.Column.VectorKind.Scalar || col.HasSlotNames();
    }
    var vecKind = varVector ?
        SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector;

    // Assemble the output column's metadata from the aggregated flags.
    List<SchemaShape.Column> meta = new List<SchemaShape.Column>();
    if (isNormalized)
        meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
    if (hasCategoricals)
        meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberType.I4, false));
    if (hasSlotNames)
        meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));

    return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta));
}
/// <summary>
/// Validates the online-linear arguments and configures trainer info; online learners
/// support incremental training.
/// </summary>
protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
    : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label,
        TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights))
{
    Contracts.CheckValue(args, nameof(args));

    // Basic hyperparameter sanity checks.
    Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
    Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
    Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);

    Args = args;
    Name = name;
    // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
    Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
}
/// <summary>
/// Returns whether the column carries key-values metadata in the expected form:
/// a vector of text.
/// </summary>
internal static bool HasKeyValues(this SchemaShape.Column col)
{
    if (!col.Metadata.TryFindColumn(Kinds.KeyValues, out var metaCol))
        return false;
    return metaCol.Kind == SchemaShape.Column.VectorKind.Vector && metaCol.ItemType.IsText;
}
/// <summary>
/// Forwards the feature, label and optional weight column shapes to the base trainer;
/// no additional state is initialized here.
/// </summary>
public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default)
    : base(host, feature, label, weight)
{
}
/// <summary>
/// Pass-through constructor: forwards all column names and the advanced-settings delegate
/// to the base FastTree trainer without additional boosting-specific setup.
/// </summary>
protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
    string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
    : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
{
}
/// <summary>
/// Validates the averaged-learner hyperparameters (learning rate, L2 regularization,
/// recency gain, averaging tolerance) before storing the arguments.
/// </summary>
protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
    : base(args, env, name, label)
{
    Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
    Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

    // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
    Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
    Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
    Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);

    // Verify user didn't specify parameters that conflict
    Contracts.Check(!args.DoLazyUpdates || !args.RecencyGainMulti && args.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");

    Args = args;
}