public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var columnDictionary = inputSchema.ToDictionary(x => x.Name);

            for (int i = 0; i < _columns.Length; i++)
            {
                for (int j = 0; j < _columns[i].InputColumnNames.Length; j++)
                {
                    if (!inputSchema.TryFindColumn(_columns[i].InputColumnNames[j], out var inputCol))
                    {
                        throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[i].InputColumnNames[j]);
                    }
                }

                // Make sure there is at most one vector valued source column.
                var inputTypes = new DataViewType[_columns[i].InputColumnNames.Length];
                var ivec       = FindVectorInputColumn(_host, _columns[i].InputColumnNames, inputSchema, inputTypes);
                var node       = ParseAndBindLambda(_host, _columns[i].Expression, ivec, inputTypes, out var perm);

                var typeRes = node.ResultType;
                _host.Assert(typeRes is PrimitiveDataViewType);

                // If one of the input columns is a vector column, we pass through the slot names metadata.
                SchemaShape.Column.VectorKind outputVectorKind;
                var metadata = new List<SchemaShape.Column>();
                if (ivec == -1)
                {
                    outputVectorKind = SchemaShape.Column.VectorKind.Scalar;
                }
                else
                {
                    inputSchema.TryFindColumn(_columns[i].InputColumnNames[ivec], out var vectorCol);
                    outputVectorKind = vectorCol.Kind;
                    if (vectorCol.HasSlotNames())
                    {
                        var b = vectorCol.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotNames);
                        _host.Assert(b);
                        metadata.Add(slotNames);
                    }
                }
                var outputSchemaShapeColumn = new SchemaShape.Column(_columns[i].Name, outputVectorKind, typeRes, false, new SchemaShape(metadata));
                columnDictionary[_columns[i].Name] = outputSchemaShapeColumn;
            }
            return new SchemaShape(columnDictionary.Values);
        }
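For readers unfamiliar with the constructor calls above: the boolean literal passed as the fourth argument to SchemaShape.Column is the is-key flag, and the last argument carries the column's annotations (such as slot names). A minimal sketch, assuming the same internal ML.NET types are in scope and using a hypothetical "Score" column:

        // Sketch only: building an output column shape like the ones returned above.
        var annotations = new List<SchemaShape.Column>();
        var outputCol = new SchemaShape.Column(
            "Score",                               // column name (hypothetical)
            SchemaShape.Column.VectorKind.Scalar,  // scalar rather than a vector
            NumberDataViewType.Single,             // item type
            false,                                 // is-key flag: the bare boolean seen above
            new SchemaShape(annotations));         // optional annotations, e.g. SlotNames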
Example #2
        protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);

            Action error =
                () => throw Host.ExceptSchemaMismatch(nameof(labelCol), RoleMappedSchema.ColumnRole.Label.Value, labelCol.Name, "BL, R8, R4 or a Key", labelCol.GetTypeString());

            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
            {
                error();
            }

            if (!labelCol.IsKey && labelCol.ItemType != NumberType.R4 && labelCol.ItemType != NumberType.R8 && !(labelCol.ItemType is BoolType))
            {
                error();
            }
        }
        private protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
        {
            Contracts.Assert(labelCol.IsValid);

            Action error =
                () => throw Host.ExceptSchemaMismatch(nameof(labelCol), "label", labelCol.Name, "float, double, bool or KeyType", labelCol.GetTypeString());

            if (labelCol.Kind != SchemaShape.Column.VectorKind.Scalar)
            {
                error();
            }

            if (!labelCol.IsKey && labelCol.ItemType != NumberDataViewType.Single && labelCol.ItemType != NumberDataViewType.Double && !(labelCol.ItemType is BooleanDataViewType))
            {
                error();
            }
        }
Example #4
 protected BoostingFastTreeTrainerBase(IHostEnvironment env,
                                       SchemaShape.Column label,
                                       string featureColumn,
                                       string weightColumn,
                                       string groupIdColumn,
                                       int numLeaves,
                                       int numTrees,
                                       int minDocumentsInLeafs,
                                       double learningRate,
                                       Action<TArgs> advancedSettings)
     : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
 {
     if (Args.LearningRates != learningRate)
     {
         using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
             Args.LearningRates = learningRate;
     }
 }
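For context, a hedged sketch of how these boosting knobs surface in the public API (assumes Microsoft.ML 1.x, an MLContext named mlContext, and an IDataView named trainData; both names are placeholders):

        // Hypothetical usage: the learning rate and tree parameters above map to these arguments.
        var fastTree = mlContext.BinaryClassification.Trainers.FastTree(
            numberOfLeaves: 20,
            numberOfTrees: 100,
            minimumExampleCountPerLeaf: 10,
            learningRate: 0.2);
        var model = fastTree.Fit(trainData);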
Example #5
        private static ColumnType MakeColumnType(SchemaShape.Column inputCol)
        {
            ColumnType curType = inputCol.ItemType;

            if (inputCol.IsKey)
            {
                curType = new KeyType(curType.AsPrimitive.RawKind, 0, AllKeySizes);
            }
            if (inputCol.Kind == SchemaShape.Column.VectorKind.VariableVector)
            {
                curType = new VectorType(curType.AsPrimitive, 0);
            }
            else if (inputCol.Kind == SchemaShape.Column.VectorKind.Vector)
            {
                curType = new VectorType(curType.AsPrimitive, AllVectorSizes);
            }
            return curType;
        }
Example #6
        private static ColumnType MakeColumnType(SchemaShape.Column column)
        {
            ColumnType curType = column.ItemType;

            if (column.IsKey)
            {
                curType = new KeyType(((PrimitiveType)curType).RawKind, 0, AllKeySizes);
            }
            if (column.Kind == SchemaShape.Column.VectorKind.VariableVector)
            {
                curType = new VectorType((PrimitiveType)curType, 0);
            }
            else if (column.Kind == SchemaShape.Column.VectorKind.Vector)
            {
                curType = new VectorType((PrimitiveType)curType, AllVectorSizes);
            }
            return curType;
        }
Example #7
        internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
                   labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            Host.CheckValue(args, nameof(args));
            Args = args;

            Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
                                   nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
            Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
            Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
            Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
            Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
            Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
            Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
            Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

            L2Weight      = Args.L2Weight;
            L1Weight      = Args.L1Weight;
            OptTol        = Args.OptTol;
            MemorySize    = Args.MemorySize;
            MaxIterations = Args.MaxIterations;
            SgdInitializationTolerance = Args.SgdInitializationTolerance;
            Quiet                = Args.Quiet;
            InitWtsDiameter      = Args.InitWtsDiameter;
            UseThreads           = Args.UseThreads;
            NumThreads           = Args.NumThreads;
            DenseOptimizer       = Args.DenseOptimizer;
            EnforceNonNegativity = Args.EnforceNonNegativity;

            if (EnforceNonNegativity && ShowTrainingStats)
            {
                ShowTrainingStats = false;
                using (var ch = Host.Start("Initialization"))
                {
                    ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
                    ch.Done();
                }
            }

            _srcPredictor = default;
        }
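A hedged sketch of how these L-BFGS options are typically set from the public API in later releases (names assume Microsoft.ML 1.x, where some of the options shown above were renamed, e.g. MemorySize to historySize):

        // Hypothetical usage: regularization weights, tolerance and history size map to the
        // fields validated in the constructor above.
        var lbfgs = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
            l1Regularization: 0.1f,
            l2Regularization: 0.1f,
            optimizationTolerance: 1e-7f,
            historySize: 50,
            enforceNonNegativity: false);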
Example #8
        /// <summary>
        /// Initializes the <see cref="MetaMulticlassClassificationTrainer{TTransformer, TModel}"/> from the <see cref="OptionsBase"/> class.
        /// </summary>
        /// <param name="env">The private instance of the <see cref="IHostEnvironment"/>.</param>
        /// <param name="options">The legacy arguments <see cref="OptionsBase"/>class.</param>
        /// <param name="name">The component name.</param>
        /// <param name="labelColumn">The label column for the metalinear trainer and the binary trainer.</param>
        /// <param name="singleEstimator">The binary estimator.</param>
        /// <param name="calibrator">The calibrator. If a calibrator is not explicitly provided, it will default to <see cref="PlattCalibratorTrainer"/></param>
        internal MetaMulticlassClassificationTrainer(IHostEnvironment env, OptionsBase options, string name, string labelColumn = null,
            TScalarTrainer singleEstimator = null, ICalibratorTrainer calibrator = null)
        {
            Host = Contracts.CheckRef(env, nameof(env)).Register(name);
            Host.CheckValue(options, nameof(options));
            Args = options;

            if (labelColumn != null)
                LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            Trainer = singleEstimator ?? CreateTrainer();

            Calibrator = calibrator ?? new PlattCalibratorTrainer(env);
            if (options.Calibrator != null)
                Calibrator = options.Calibrator.CreateComponent(Host);

            // Regarding caching, no matter what the internal predictor, we're performing many passes
            // simply by virtue of this being a meta-trainer, so we will still cache.
            Info = new TrainerInfo(normalization: Trainer.Info.NeedNormalization);
        }
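A hedged usage sketch of a concrete meta-multiclass trainer built this way (assumes an MLContext named mlContext and an IDataView named trainData with a key-typed "Label" column):

        // Hypothetical usage: OneVersusAll is a MetaMulticlassClassificationTrainer that wraps a
        // binary trainer and calibrates its scores (Platt by default, as noted above).
        var ova = mlContext.MulticlassClassification.Trainers.OneVersusAll(
            mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
            labelColumnName: "Label");
        var model = ova.Fit(trainData);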
Example #9
        /// <summary>
        /// Schema propagation for transformers. Returns the output schema of the data, if
        /// the input schema is like the one provided.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string columnRole)
            {
                if (!inputSchema.TryFindColumn(cachedColumn.Name, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(col), columnRole, cachedColumn.Name);
                }

                if (!cachedColumn.IsCompatibleWith(col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), columnRole, cachedColumn.Name,
                                                     cachedColumn.GetTypeString(), col.GetTypeString());
                }
            }

            // Check if label column is good.
            var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false);

            CheckColumnsCompatible(labelColumn, "label");

            // Check if columns of matrix's row and column indexes are good. Note that column of IDataView and column of matrix are two different things.
            var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
            var matrixRowIndexColumn    = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);

            CheckColumnsCompatible(matrixColumnIndexColumn, "matrixColumnIndex");
            CheckColumnsCompatible(matrixRowIndexColumn, "matrixRowIndex");

            // Input columns just pass through so that output column dictionary contains all input columns.
            var outColumns = inputSchema.ToDictionary(x => x.Name);

            // Add columns produced by this estimator.
            foreach (var col in GetOutputColumnsCore(inputSchema))
            {
                outColumns[col.Name] = col;
            }

            return new SchemaShape(outColumns.Values);
        }
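A hedged usage sketch of the matrix factorization estimator whose schema checks are shown above (assumes the Microsoft.ML.Recommender package and key-typed index columns; the column names are placeholders):

        // Hypothetical usage: the label is a float scalar and both index columns are key-typed,
        // matching the three SchemaShape.Column checks above.
        var options = new MatrixFactorizationTrainer.Options
        {
            LabelColumnName = "Label",
            MatrixColumnIndexColumnName = "UserIdKey",
            MatrixRowIndexColumnName = "MovieIdKey",
            NumberOfIterations = 20,
            ApproximationRank = 100
        };
        var mf = mlContext.Recommendation().Trainers.MatrixFactorization(options);
        var model = mf.Fit(trainData);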
Example #10
        internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
                                                       string[] featureColumns,
                                                       string labelColumn = DefaultColumnNames.Label,
                                                       string weights     = null)
            : base(env, LoadName)
        {
            var args = new Options();

            Initialize(env, args);
            Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

            FeatureColumns = new SchemaShape.Column[featureColumns.Length];

            for (int i = 0; i < featureColumns.Length; i++)
            {
                FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
            }

            LabelColumn  = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
            WeightColumn = weights != null ? new SchemaShape.Column(weights, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
        }
        /// <summary>
        /// Legacy constructor initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/> through the legacy
        /// <see cref="Arguments"/> class.
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="args">An instance of the legacy <see cref="Arguments"/> to apply advanced parameters to the algorithm.</param>
        public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args)
            : base(env, LoadName)
        {
            Initialize(env, args);
            Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

            // There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
            FeatureColumns = new SchemaShape.Column[1 + args.ExtraFeatureColumns.Length];

            // Treat the default feature column as the 1st field.
            FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

            // Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
            for (int i = 0; args.ExtraFeatureColumns != null && i < args.ExtraFeatureColumns.Length; i++)
            {
                FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
            }

            LabelColumn  = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
            WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;
        }
        private protected CalibratorEstimatorBase(IHostEnvironment env,
                                                  ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn)
        {
            Host = env;
            _calibratorTrainer = calibratorTrainer;

            if (!string.IsNullOrWhiteSpace(labelColumn))
            {
                LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn);
            }
            else
            {
                env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified.");
            }
            ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn); // Do we fathom this being named anything else (renaming column)? Complete metadata?

            if (weightColumn != null)
            {
                WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn);
            }
        }
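A hedged sketch of attaching a calibrator estimator like the one above through the public API (assumes a scorer has already produced a "Score" column in an IDataView named scoredData):

        // Hypothetical usage: a Platt calibrator trained on the label and the existing Score column.
        var platt = mlContext.BinaryClassification.Calibrators.Platt(
            labelColumnName: "Label",
            scoreColumnName: "Score");
        var calibrated = platt.Fit(scoredData).Transform(scoredData);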
Example #13
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.Columns.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
                }

                if ((col.ItemType.ItemType.RawKind == default) || !(col.ItemType.IsVector || col.ItemType.IsPrimitive))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
                }
                SchemaShape metadata;

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
                if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                col.ItemType, col.IsKey);
                }
                Contracts.AssertValue(kv);

                if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
                {
                    metadata = new SchemaShape(new[] { slotMeta, kv });
                }
                else
                {
                    metadata = new SchemaShape(new[] { kv });
                }
                result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
Example #14
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));
            var result = inputSchema.ToDictionary(x => x.Name);

            foreach (var colInfo in _columns)
            {
                if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                }

                if (!col.ItemType.IsStandardScalar())
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName);
                }
                SchemaShape metadata;

                // In the event that we are transforming something that is of type key, we will get their type of key value
                // metadata, unless it has none or is not vector in which case we back off to having key values over the item type.
                if (!col.IsKey || !col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector)
                {
                    kv = new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
                                                colInfo.TextKeyValues ? TextDataViewType.Instance : col.ItemType, col.IsKey);
                }
                Contracts.Assert(kv.IsValid);

                if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
                {
                    metadata = new SchemaShape(new[] { slotMeta, kv });
                }
                else
                {
                    metadata = new SchemaShape(new[] { kv });
                }
                result[colInfo.OutputColumnName] = new SchemaShape.Column(colInfo.OutputColumnName, col.Kind, NumberDataViewType.UInt32, true, metadata);
            }

            return new SchemaShape(result.Values);
        }
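A hedged usage sketch of the value-to-key mapping whose output shape is computed above (column names are placeholders):

        // Hypothetical usage: the output column is key-typed (the UInt32 shape above) and keeps
        // KeyValues annotations so the original values can be mapped back later.
        var toKey = mlContext.Transforms.Conversion.MapValueToKey(
            outputColumnName: "LabelKey",
            inputColumnName: "Label");
        var keyed = toKey.Fit(data).Transform(data);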
Example #15
        internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env,
                                                       string[] featureColumnNames,
                                                       string labelColumnName         = DefaultColumnNames.Label,
                                                       string exampleWeightColumnName = null)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoadName);

            var args = new Options();

            Initialize(env, args);

            FeatureColumns = new SchemaShape.Column[featureColumnNames.Length];

            for (int i = 0; i < featureColumnNames.Length; i++)
            {
                FeatureColumns[i] = new SchemaShape.Column(featureColumnNames[i], SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }

            LabelColumn  = new SchemaShape.Column(labelColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
            WeightColumn = exampleWeightColumnName != null ? new SchemaShape.Column(exampleWeightColumnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false) : default;
        }
        private protected LightGbmTrainerBase(IHostEnvironment env,
                                              string name,
                                              SchemaShape.Column label,
                                              string featureColumn,
                                              string weightColumn,
                                              string groupIdColumn,
                                              int? numLeaves,
                                              int? minDataPerLeaf,
                                              double? learningRate,
                                              int numBoostRound,
                                              Action<LightGbmArguments> advancedSettings)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
        {
            Args = new LightGbmArguments();

            Args.NumLeaves      = numLeaves;
            Args.MinDataPerLeaf = minDataPerLeaf;
            Args.LearningRate   = learningRate;
            Args.NumBoostRound  = numBoostRound;

            //apply the advanced args, if the user supplied any
            advancedSettings?.Invoke(Args);

            Args.LabelColumn   = label.Name;
            Args.FeatureColumn = featureColumn;

            if (weightColumn != null)
            {
                Args.WeightColumn = Optional<string>.Explicit(weightColumn);
            }

            if (groupIdColumn != null)
            {
                Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn);
            }

            InitParallelTraining();
        }
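A hedged sketch of the corresponding public LightGBM entry point (assumes the Microsoft.ML.LightGbm package and placeholder mlContext / trainData names):

        // Hypothetical usage: leaves, minimum examples per leaf, learning rate and boosting rounds
        // correspond to NumLeaves, MinDataPerLeaf, LearningRate and NumBoostRound above.
        var lgbm = mlContext.BinaryClassification.Trainers.LightGbm(
            numberOfLeaves: 31,
            minimumExampleCountPerLeaf: 20,
            learningRate: 0.1,
            numberOfIterations: 100);
        var model = lgbm.Fit(trainData);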
Example #17
        internal FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoadName);

            Initialize(env, options);
            var extraColumnLength = (options.ExtraFeatureColumns != null ? options.ExtraFeatureColumns.Length : 0);

            // There can be multiple feature columns in FFM, jointly specified by args.FeatureColumnName and args.ExtraFeatureColumns.
            FeatureColumns = new SchemaShape.Column[1 + extraColumnLength];

            // Treat the default feature column as the 1st field.
            FeatureColumns[0] = new SchemaShape.Column(options.FeatureColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);

            // Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
            for (int i = 0; i < extraColumnLength; i++)
            {
                FeatureColumns[i + 1] = new SchemaShape.Column(options.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
            }

            LabelColumn  = new SchemaShape.Column(options.LabelColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
            WeightColumn = options.ExampleWeightColumnName != null ? new SchemaShape.Column(options.ExampleWeightColumnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false) : default;
        }
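A hedged usage sketch of the field-aware factorization machine configured above, where each field is a separate float-vector column (column names are placeholders):

        // Hypothetical usage: two feature fields and a boolean label, matching the
        // FeatureColumns / LabelColumn shapes built in the constructor above.
        var ffm = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(
            new[] { "FieldA", "FieldB" },
            labelColumnName: "Label");
        var model = ffm.Fit(trainData);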
Example #18
        private static TOptions ArgsInit(string featureColumn, SchemaShape.Column labelColumn,
                                         string weightColumn,
                                         float l1Weight,
                                         float l2Weight,
                                         float optimizationTolerance,
                                         int memorySize,
                                         bool enforceNoNegativity)
        {
            var args = new TOptions
            {
                FeatureColumnName       = featureColumn,
                LabelColumnName         = labelColumn.Name,
                ExampleWeightColumnName = weightColumn,
                L1Regularization        = l1Weight,
                L2Regularization        = l2Weight,
                OptmizationTolerance    = optimizationTolerance,
                IterationsToRemember    = memorySize,
                EnforceNonNegativity    = enforceNoNegativity
            };

            args.ExampleWeightColumnName = weightColumn;
            return args;
        }
Example #19
 internal LbfgsTrainerBase(IHostEnvironment env,
                           string featureColumn,
                           SchemaShape.Column labelColumn,
                           string weightColumn,
                           float l1Weight,
                           float l2Weight,
                           float optimizationTolerance,
                           int memorySize,
                           bool enforceNoNegativity)
     : this(env, new TArgs
 {
     FeatureColumn = featureColumn,
     LabelColumn = labelColumn.Name,
      WeightColumn = weightColumn != null ? Optional<string>.Explicit(weightColumn) : Optional<string>.Implicit(DefaultColumnNames.Weight),
     L1Weight = l1Weight,
     L2Weight = l2Weight,
     OptTol = optimizationTolerance,
     MemorySize = memorySize,
     EnforceNonNegativity = enforceNoNegativity
 },
            labelColumn)
 {
 }
Example #20
 internal LbfgsTrainerBase(IHostEnvironment env,
                           string featureColumn,
                           SchemaShape.Column labelColumn,
                           string weightColumn,
                           float l1Weight,
                           float l2Weight,
                           float optimizationTolerance,
                           int memorySize,
                           bool enforceNoNegativity)
     : this(env, new TOptions
 {
     FeatureColumnName = featureColumn,
     LabelColumnName = labelColumn.Name,
     ExampleWeightColumnName = weightColumn,
     L1Regularization = l1Weight,
     L2Regularization = l2Weight,
     OptmizationTolerance = optimizationTolerance,
     IterationsToRemember = memorySize,
     EnforceNonNegativity = enforceNoNegativity
 },
            labelColumn)
 {
 }
Example #21
        /// <summary>
        /// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>.
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param>
        /// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param>
        /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
        /// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
        public MatrixFactorizationTrainer(IHostEnvironment env, string labelColumn, string matrixColumnIndexColumnName, string matrixRowIndexColumnName,
                                          TrainerEstimatorContext context = null, Action <Arguments> advancedSettings = null)
            : base(env, LoadNameValue)
        {
            var args = new Arguments();

            advancedSettings?.Invoke(args);

            _lambda  = args.Lambda;
            _k       = args.K;
            _iter    = args.NumIterations;
            _eta     = args.Eta;
            _threads = args.NumThreads ?? Environment.ProcessorCount;
            _quiet   = args.Quiet;
            _doNmf   = args.NonNegative;

            Info    = new TrainerInfo(normalization: false, caching: false);
            Context = context;

            LabelColumn             = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
            MatrixColumnIndexColumn = new SchemaShape.Column(matrixColumnIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
            MatrixRowIndexColumn    = new SchemaShape.Column(matrixRowIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
        }
        /// <summary>
        /// Initializing a new instance of <see cref="FieldAwareFactorizationMachineTrainer"/>.
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumn">The name of the label column.</param>
        /// <param name="featureColumns">The name of  column hosting the features.</param>
        /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
        /// <param name="weightColumn">The name of the weight column.</param>
        /// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
        public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string[] featureColumns,
                                                     string weightColumn = null, TrainerEstimatorContext context = null, Action <Arguments> advancedSettings = null)
            : base(env, LoadName)
        {
            var args = new Arguments();

            advancedSettings?.Invoke(args);

            Initialize(env, args);
            Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);

            Context = context;

            FeatureColumns = new SchemaShape.Column[featureColumns.Length];

            for (int i = 0; i < featureColumns.Length; i++)
            {
                FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
            }

            LabelColumn  = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
            WeightColumn = weightColumn != null ? new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : null;
        }
        private static TArgs ArgsInit(string featureColumn, SchemaShape.Column labelColumn,
                                      string weightColumn,
                                      float l1Weight,
                                      float l2Weight,
                                      float optimizationTolerance,
                                      int memorySize,
                                      bool enforceNoNegativity)
        {
            var args = new TArgs
            {
                FeatureColumn        = featureColumn,
                LabelColumn          = labelColumn.Name,
                WeightColumn         = weightColumn ?? Optional<string>.Explicit(weightColumn),
                L1Weight             = l1Weight,
                L2Weight             = l2Weight,
                OptTol               = optimizationTolerance,
                MemorySize           = memorySize,
                EnforceNonNegativity = enforceNoNegativity
            };

            args.WeightColumn = weightColumn;
            return args;
        }
Example #24
        private protected LightGbmTrainerBase(IHostEnvironment env, string name, SchemaShape.Column label, string featureColumn,
                                              string weightColumn = null, string groupIdColumn = null, Action<LightGbmArguments> advancedSettings = null)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
        {
            Args = new LightGbmArguments();

            //apply the advanced args, if the user supplied any
            advancedSettings?.Invoke(Args);
            Args.LabelColumn   = label.Name;
            Args.FeatureColumn = featureColumn;

            if (weightColumn != null)
            {
                Args.WeightColumn = weightColumn;
            }

            if (groupIdColumn != null)
            {
                Args.GroupIdColumn = groupIdColumn;
            }

            InitParallelTraining();
        }
        private SchemaShape.Column CheckInputsAndMakeColumn(
            SchemaShape inputSchema, string name, string[] sources)
        {
            _host.AssertNonEmpty(sources);

            var cols = new SchemaShape.Column[sources.Length];
            // If any input is a var vector, so is the output.
            bool varVector = false;
            // If any input is not normalized, the output is not normalized.
            bool isNormalized = true;
            // If any input has categorical indices, so will the output.
            bool hasCategoricals = false;
            // If any is scalar or had slot names, then the output will have slot names.
            bool hasSlotNames = false;

            // We will get the item type from the first column.
            ColumnType itemType = null;

            for (int i = 0; i < sources.Length; ++i)
            {
                if (!inputSchema.TryFindColumn(sources[i], out var col))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", sources[i]);
                }
                if (i == 0)
                {
                    itemType = col.ItemType;
                }
                // For the sake of an estimator I am going to have a hard policy of no keys.
                // Appending keys makes no real sense anyway.
                if (col.IsKey)
                {
                    throw _host.Except($"Column '{sources[i]}' is key." +
                                       $"Concatenation of keys is unsupported.");
                }
                if (!col.ItemType.Equals(itemType))
                {
                    throw _host.Except($"Column '{sources[i]}' has values of {col.ItemType}" +
                                       $"which is not the same as earlier observed type of {itemType}.");
                }
                varVector       |= col.Kind == SchemaShape.Column.VectorKind.VariableVector;
                isNormalized    &= col.IsNormalized();
                hasCategoricals |= HasCategoricals(col);
                hasSlotNames    |= col.Kind == SchemaShape.Column.VectorKind.Scalar || col.HasSlotNames();
            }
            var vecKind = varVector ? SchemaShape.Column.VectorKind.VariableVector :
                          SchemaShape.Column.VectorKind.Vector;

            List<SchemaShape.Column> meta = new List<SchemaShape.Column>();

            if (isNormalized)
            {
                meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));
            }
            if (hasCategoricals)
            {
                meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberType.I4, false));
            }
            if (hasSlotNames)
            {
                meta.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
            }

            return new SchemaShape.Column(name, vecKind, itemType, false, new SchemaShape(meta));
        }
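A hedged usage sketch of the concatenation estimator whose schema logic is shown above (all sources must share one item type and none may be key-typed; column names are placeholders):

        // Hypothetical usage: concatenating two numeric columns into a single "Features" vector;
        // IsNormalized, CategoricalSlotRanges and SlotNames metadata propagate as computed above.
        var concat = mlContext.Transforms.Concatenate("Features", "NumericA", "NumericB");
        var withFeatures = concat.Fit(data).Transform(data);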
        protected OnlineLinearTrainer(OnlineLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
            : base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.InitialWeights))
        {
            Contracts.CheckValue(args, nameof(args));
            Contracts.CheckUserArg(args.NumIterations > 0, nameof(args.NumIterations), UserErrorPositive);
            Contracts.CheckUserArg(args.InitWtsDiameter >= 0, nameof(args.InitWtsDiameter), UserErrorNonNegative);
            Contracts.CheckUserArg(args.StreamingCacheSize > 0, nameof(args.StreamingCacheSize), UserErrorPositive);

            Args = args;
            Name = name;
            // REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
            Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
        }
 internal static bool HasKeyValues(this SchemaShape.Column col)
 {
     return col.Metadata.TryFindColumn(Kinds.KeyValues, out var metaCol) &&
            metaCol.Kind == SchemaShape.Column.VectorKind.Vector &&
            metaCol.ItemType.IsText;
 }
 public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = default)
     : base(host, feature, label, weight)
 {
 }
 protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
                                       string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
     : base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
 {
 }
Example #30
        protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
            : base(args, env, name, label)
        {
            Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
            Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0, nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

            // Weights are scaled down by 2 * L2 regularization on each update step, so 0.5 would scale all weights to 0, which is not sensible.
            Contracts.CheckUserArg(0 <= args.L2RegularizerWeight && args.L2RegularizerWeight < 0.5, nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
            Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
            Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);
            // Verify user didn't specify parameters that conflict
            Contracts.Check(!args.DoLazyUpdates || !args.RecencyGainMulti && args.RecencyGain == 0, "Cannot have both recency gain and lazy updates.");

            Args = args;
        }