Exemplo n.º 1
0
        /// <summary>
        /// Creates a loader over a single transposed binary file.
        /// </summary>
        /// <param name="env">Environment used to register this loader's host.</param>
        /// <param name="args">Loader arguments; only <c>Threads</c> is read here.</param>
        /// <param name="file">Stream source; must contain exactly one file.</param>
        public TransposeLoader(IHostEnvironment env, Arguments args, IMultiStreamSource file)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoadName);
            _host.CheckValue(args, nameof(args));
            _host.CheckValue(file, nameof(file));
            _host.Check(file.Count == 1, "Transposed loader accepts a single file only");

            // An unspecified or negative thread count is stored as 0.
            var requestedThreads = args.Threads ?? 0;
            _threads = requestedThreads > 0 ? requestedThreads : 0;

            _file = file;
            using (var stream = _file.Open(0))
            using (var reader = new BinaryReader(stream))
            {
                // Read the header, then jump to the sub-IDV table it points at.
                _header = InitHeader(reader);
                reader.Seek(_header.SubIdvTableOffset);
                _schemaEntry = new SubIdvEntry.SchemaSubIdv(this, reader);

                // One transposed sub-IDV entry per column, read sequentially.
                _entries = new SubIdvEntry.TransposedSubIdv[_header.ColumnCount];
                for (int col = 0; col < _entries.Length; col++)
                {
                    _entries[col] = new SubIdvEntry.TransposedSubIdv(this, reader, col);
                }

                _schema = new SchemaImpl(this);

                // Without row data, per-column transposers are allocated lazily under this lock.
                if (!HasRowData)
                {
                    _colTransposers = new Transposer[_header.ColumnCount];
                    _colTransposersLock = new object();
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates the resampling transform over <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Environment used to register the transform's host.</param>
        /// <param name="args">Arguments; <c>lambda</c> must be positive, and when <c>column</c> is set it must
        /// exist in the input schema and <c>classValue</c> must be non-empty.</param>
        /// <param name="input">Data view to resample.</param>
        public ResampleTransform(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("ResampleTransform");
            _host.CheckValue(args, nameof(args));
            _host.CheckValue(input, nameof(input));
            _host.Check(args.lambda > 0, "lambda must be > 0");

            _input = input;
            _args = args;
            _cacheReplica = null;

            bool restrictedToClass = !string.IsNullOrEmpty(_args.column);
            if (restrictedToClass)
            {
                // The class column must exist in the input schema...
                int columnIndex;
                bool columnFound = _input.Schema.TryGetColumnIndex(_args.column, out columnIndex);
                if (!columnFound)
                {
                    throw _host.Except("Unable to find column '{0}' in\n'{1}'.", _args.column, SchemaHelper.ToString(_input.Schema));
                }

                // ...and a class value to filter on must be provided.
                if (string.IsNullOrEmpty(_args.classValue))
                {
                    throw _host.Except("Class value cannot be null.");
                }
            }

            _transform = CreateTemplatedTransform();
        }
            /// <summary>
            /// Restores the fields from <paramref name="ctx"/> in this order: columns (comma-separated string),
            /// hists (comma-separated string), saveInFile, passThrough, showSchema, dimension,
            /// oneRowPerColumn, jsonFormat (each an Int32 where 1 means true) and name.
            /// Empty comma-separated lists are stored as null.
            /// </summary>
            public void Read(ModelLoadContext ctx, IHost host)
            {
                var reader = ctx.Reader;

                string serializedColumns = reader.ReadString();
                host.CheckValue(serializedColumns, "columns");
                string[] parsedColumns = serializedColumns.Split(new[] { ',' }, System.StringSplitOptions.RemoveEmptyEntries);
                columns = parsedColumns.Length == 0 ? null : parsedColumns;

                string serializedHists = reader.ReadString();
                host.CheckValue(serializedHists, "hists");
                string[] parsedHists = serializedHists.Split(new[] { ',' }, System.StringSplitOptions.RemoveEmptyEntries);
                hists = parsedHists.Length == 0 ? null : parsedHists;

                saveInFile = reader.ReadString();

                // passThrough is the only flag whose encoding is validated (must be 0 or 1).
                int passThroughFlag = reader.ReadInt32();
                host.Check(passThroughFlag == 0 || passThroughFlag == 1, "passThrough");
                passThrough = passThroughFlag == 1;

                showSchema = reader.ReadInt32() == 1;
                dimension = reader.ReadInt32() == 1;
                oneRowPerColumn = reader.ReadInt32() == 1;
                jsonFormat = reader.ReadInt32() == 1;
                name = reader.ReadString();
            }
Exemplo n.º 4
0
        /// <summary>
        /// Loading constructor: reads the thread count from <paramref name="ctx"/> and re-reads the
        /// header and sub-IDV table of the transposed file in <paramref name="file"/>.
        /// </summary>
        private TransposeLoader(IHost host, ModelLoadContext ctx, IMultiStreamSource file)
        {
            Contracts.CheckValue(host, nameof(host));
            _host = host;
            _host.CheckValue(file, nameof(file));
            _host.Check(file.Count == 1, "Transposed loader accepts a single file only");

            // *** Binary format **
            // int: Number of threads if explicitly defined, or 0 if the
            //      number of threads was automatically determined
            _threads = ctx.Reader.ReadInt32();
            _host.CheckDecode(_threads >= 0);

            // The file-reading logic below duplicates the public constructor.
            // NOTE(review): unlike the public constructor, this path never assigns _schema
            // (new SchemaImpl(this)); confirm the schema is initialized elsewhere when loading.
            _file = file;
            using (var stream = _file.Open(0))
            using (var reader = new BinaryReader(stream))
            {
                _header = InitHeader(reader);
                reader.Seek(_header.SubIdvTableOffset);
                _schemaEntry = new SubIdvEntry.SchemaSubIdv(this, reader);

                _entries = new SubIdvEntry.TransposedSubIdv[_header.ColumnCount];
                for (int col = 0; col < _entries.Length; col++)
                {
                    _entries[col] = new SubIdvEntry.TransposedSubIdv(this, reader, col);
                }

                if (!HasRowData)
                {
                    _colTransposers = new Transposer[_header.ColumnCount];
                    _colTransposersLock = new object();
                }
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Builds a key of the form "{learner}+{transformsBitMask}" from the learner's string form.
        /// The learner's string form must be non-empty.
        /// </summary>
        protected string GetHashKey(long transformsBitMask, RecipeInference.SuggestedRecipe.SuggestedLearner learner)
        {
            string learnerText = learner.ToString();
            Host.Check(!string.IsNullOrEmpty(learnerText));

            string key = learnerText + "+" + transformsBitMask;
            return key;
        }
Exemplo n.º 6
0
 /// <summary>
 /// Serializes the combiner via SaveCore. Throws if the meta-model (Meta) has not been trained.
 /// </summary>
 void ICanSaveModel.Save(ModelSaveContext ctx)
 {
     bool isTrained = Meta != null;
     Host.Check(isTrained, "Can't save an untrained Stacking combiner");

     Host.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();

     SaveCore(ctx);
 }
        /// <summary>
        /// Creates the estimator from pre-built <paramref name="options"/>. Not supported on CentOS 7.
        /// </summary>
        internal DateTimeEstimator(IHostEnvironment env, Options options)
        {
            // env is already known to be non-null after this check, so it can be registered directly.
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("DateTimeTransformerEstimator");

            bool unsupportedPlatform = CommonExtensions.OsIsCentOS7();
            _host.Check(!unsupportedPlatform, "CentOS7 is not supported");

            _options = options;
        }
        /// <summary>
        /// Trains via TrainModelCore and requires the result to be an <c>IPredictor</c>.
        /// </summary>
        IPredictor ITrainer<IPredictor>.Train(TrainContext context)
        {
            Host.CheckValue(context, nameof(context));

            IPredictor trained = TrainModelCore(context) as IPredictor;
            Host.Check(trained != null, "Training did not return a predictor.");

            return trained;
        }
Exemplo n.º 9
0
        /// <summary>
        /// Projects <paramref name="src"/> onto the random Fourier vectors of <paramref name="transformInfo"/>
        /// and writes the features to <paramref name="dst"/>.
        /// With rotation terms there are NewDim outputs: cos(dot_i + rotation_i) * sqrt(2 / NewDim).
        /// Without them there are 2 * NewDim outputs: interleaved cos(dot_i), sin(dot_i), each times sqrt(1 / NewDim).
        /// </summary>
        /// <param name="host">Host used for argument checks.</param>
        /// <param name="src">Input vector; its length must equal transformInfo.SrcDim.</param>
        /// <param name="dst">Receives the output; its Values buffer is reused when it is large enough.</param>
        /// <param name="transformInfo">Random Fourier vectors, optional rotation terms and dimensions.</param>
        /// <param name="featuresAligned">Scratch buffer that receives a copy of the input values.</param>
        /// <param name="productAligned">Scratch buffer that receives the matrix-vector product.</param>
        private static void TransformFeatures(IHost host, ref VBuffer<Float> src, ref VBuffer<Float> dst, TransformInfo transformInfo,
                                              AlignedArray featuresAligned, AlignedArray productAligned)
        {
            Contracts.AssertValue(host, "host");
            host.Check(src.Length == transformInfo.SrcDim, "column does not have the expected dimensionality.");

            var newDim = transformInfo.NewDim;
            var rotation = transformInfo.RotationTerms;
            bool useRotation = rotation != null;
            var outputLength = useRotation ? newDim : 2 * newDim;

            var values = dst.Values;
            if (Utils.Size(values) < outputLength)
            {
                values = new Float[outputLength];
            }
            Float scale = useRotation
                ? MathUtils.Sqrt((Float)2.0 / newDim)
                : MathUtils.Sqrt((Float)1.0 / newDim);

            if (src.IsDense)
            {
                featuresAligned.CopyFrom(src.Values, 0, src.Length);
                CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, featuresAligned, productAligned,
                                         transformInfo.NewDim);
            }
            else
            {
                // This overload of MatTimesSrc ignores the values in slots that are not in src.Indices, so there is
                // no need to zero them out.
                featuresAligned.CopyFrom(src.Indices, src.Values, 0, 0, src.Count, zeroItems: false);
                CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0,
                                         src.Count, productAligned, transformInfo.NewDim);
            }

            if (useRotation)
            {
                for (int slot = 0; slot < newDim; slot++)
                {
                    values[slot] = (Float)MathUtils.Cos(productAligned[slot] + rotation[slot]) * scale;
                }
            }
            else
            {
                for (int slot = 0; slot < newDim; slot++)
                {
                    var projection = productAligned[slot];
                    values[2 * slot] = (Float)MathUtils.Cos(projection) * scale;
                    values[2 * slot + 1] = (Float)MathUtils.Sin(projection) * scale;
                }
            }

            dst = new VBuffer<Float>(outputLength, values, dst.Indices);
        }
        /// <summary>
        /// Obtains a row mapper by applying the wrapped transform to an empty view with
        /// <paramref name="inputSchema"/> and wrapping the result in a CompositeRowToRowMapper.
        /// </summary>
        IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
        {
            _host.CheckValue(inputSchema, nameof(inputSchema));

            var applied = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema));
            var mapper = applied as IRowToRowMapper;
            _host.Check(mapper != null);

            var mappers = new[] { mapper };
            return new CompositeRowToRowMapper(inputSchema, mappers);
        }
Exemplo n.º 11
0
        /// <summary>
        /// Asks the base sweeper for the next batch and queues each proposal with a sequential id.
        /// Completes the channel when the sweeper proposes nothing.
        /// </summary>
        private void PrepareNextBatch(IEnumerable<IRunResult> results)
        {
            _host.Check(!_disposed, "Creating parameters while sweeper is disposed");
            var proposals = _baseSweeper.ProposeSweeps(_batchSize, results);

            bool exhausted = Utils.Size(proposals) == 0;
            if (exhausted)
            {
                // No more parameter sets: signal consumers that the queue is complete.
                _paramChannel.Writer.Complete();
                return;
            }

            var writer = _paramChannel.Writer;
            foreach (var proposal in proposals)
            {
                // Ids are assigned in proposal order; the TryWrite result is not checked.
                writer.TryWrite(new ParameterSetWithId(_numGenerated++, proposal));
            }
            EnsureResultsSize();
        }
            /// <summary>
            /// Writes the trained model to <paramref name="ctx"/>. The write order is: class indices,
            /// class values (encoded according to LabelType), the _singleColumn and _labelKey flags
            /// (Int32 0/1), the predictor count followed by each predictor as sub-model "M2B{i}", a byte
            /// flag for the reclassification predictor (saved as sub-model "Reclassification" when present),
            /// and a trailing byte 213.
            /// </summary>
            /// <param name="ctx">Save context receiving the model.</param>
            /// <param name="host">Host used for checks and exceptions.</param>
            /// <param name="versionInfo">Version information recorded in <paramref name="ctx"/>.</param>
            public void SaveCore(ModelSaveContext ctx, IHost host, VersionInfo versionInfo)
            {
                host.Check(Classes.Count > 0, "The model cannot be saved, it was never trained.");
                // NOTE(review): this second check enforces Count == Length, but it reuses the "never trained"
                // message; confirm the intended message for a Classes buffer whose Count differs from Length.
                host.Check(Classes.Count == Classes.Length, "The model cannot be saved, it was never trained.");
                ctx.SetVersionInfo(versionInfo);
                ctx.Writer.WriteIntArray(Classes.Indices);
                // Values are written with the encoding matching LabelType. NOTE(review): each branch uses an
                // 'as' cast, so a Values array of a different element type would be passed as null; confirm
                // the loader can never produce that mismatch.
                if (LabelType == NumberDataViewType.Single)
                {
                    ctx.Writer.WriteSingleArray(Classes.Values as float[]);
                }
                else if (LabelType == NumberDataViewType.Byte)
                {
                    ctx.Writer.WriteByteArray(Classes.Values as byte[]);
                }
                else if (LabelType == NumberDataViewType.UInt16)
                {
                    // UInt16 labels are widened to uint on disk.
                    ctx.Writer.WriteUIntArray((Classes.Values as ushort[]).Select(c => (uint)c).ToArray());
                }
                else if (LabelType == NumberDataViewType.UInt32)
                {
                    ctx.Writer.WriteUIntArray(Classes.Values as uint[]);
                }
                else
                {
                    throw host.Except("Unexpected type for LabelType.");
                }

                // Boolean settings are stored as Int32 0/1.
                ctx.Writer.Write(_singleColumn ? 1 : 0);
                ctx.Writer.Write(_labelKey ? 1 : 0);
                var preds = Predictors;

                // Predictor count, then each predictor as a named sub-model.
                ctx.Writer.Write(preds.Length);
                for (int i = 0; i < preds.Length; i++)
                {
                    ctx.SaveModel(preds[i], string.Format("M2B{0}", i));
                }
                // Presence flag (byte 1/0) for the optional reclassification predictor.
                ctx.Writer.Write(_reclassificationPredictor != null ? (byte)1 : (byte)0);
                if (_reclassificationPredictor != null)
                {
                    ctx.SaveModel(_reclassificationPredictor, "Reclassification");
                }
                // Trailing marker byte; presumably verified by the loader as an end-of-record check — confirm.
                ctx.Writer.Write((byte)213);
            }
 /// <summary>
 /// Constructor.
 /// </summary>
 /// <param name="env">environment used to register this transform's host</param>
 /// <param name="input">input source stored as the secondary source</param>
 /// <param name="name">name of the transform, also used as the host registration name; must be non-empty</param>
 public AbstractSimpleTransformTemplate(IHostEnvironment env, IDataView input, string name)
 {
     Contracts.CheckValue(env, nameof(env));
     // Validate the name before handing it to Register. Previously it was only checked after
     // the host had already been created from it.
     Contracts.CheckNonEmpty(name, nameof(name));
     _host = env.Register(name);
     _host.CheckValue(input, nameof(input));
     _sourceCtx = input;
     _sourcePipe = null;
     // Dedicated lock object for this instance.
     _lock = new object();
 }
Exemplo n.º 14
0
        /// <summary>
        /// Copies the trained factorization out of the native model.
        /// </summary>
        /// <param name="m">Receives the native model's M (number of rows); must be positive.</param>
        /// <param name="n">Receives the native model's N (number of columns); must be positive.</param>
        /// <param name="k">Receives the native model's K (internal dimension); must be positive.</param>
        /// <param name="p">Receives m * k floats copied from the native P buffer.</param>
        /// <param name="q">Receives n * k floats copied from the native Q buffer.</param>
        public unsafe void Get(out int m, out int n, out int k, out float[] p, out float[] q)
        {
            _host.Check(_pMFModel != null, "Attempted to get predictor before training");
            m = _pMFModel->M;
            _host.Check(m > 0, "Number of rows should have been positive but was not");
            n = _pMFModel->N;
            _host.Check(n > 0, "Number of columns should have been positive but was not");
            k = _pMFModel->K;
            _host.Check(k > 0, "Internal dimension should have been positive but was not");

            // checked(): m * k and n * k can exceed int.MaxValue for large models. An unchecked
            // wrap would size the buffers wrongly and copy only part of the factors.
            p = new float[checked(m * k)];
            q = new float[checked(n * k)];

            // The method is already unsafe, so no nested unsafe block is needed.
            Marshal.Copy((IntPtr)_pMFModel->P, p, 0, p.Length);
            Marshal.Copy((IntPtr)_pMFModel->Q, q, 0, q.Length);
        }
Exemplo n.º 15
0
        /// <summary>
        /// Creates a random feature selector. <c>FeaturesSelectionProportion</c> must lie strictly
        /// between 0 and 1.
        /// </summary>
        public RandomFeatureSelector(IHostEnvironment env, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));

            _host = env.Register(LoadName);
            _args = args;

            var proportion = _args.FeaturesSelectionProportion;
            bool inOpenUnitInterval = proportion > 0 && proportion < 1;
            _host.Check(inOpenUnitInterval,
                        "The feature proportion for RandomFeatureSelector should be greater than 0 and lesser than 1");
        }
Exemplo n.º 16
0
        /// <summary>
        /// Creates a transformer that either keeps or drops the listed columns. Exactly one of
        /// <paramref name="keepColumns"/> and <paramref name="dropColumns"/> must be non-empty.
        /// </summary>
        /// <param name="env">Environment used to register the transformer's host.</param>
        /// <param name="keepColumns">Columns to keep, or null/empty.</param>
        /// <param name="dropColumns">Columns to drop, or null/empty.</param>
        /// <param name="keepHidden">Stored as <c>KeepHidden</c>.</param>
        /// <param name="ignoreMissing">Stored as <c>IgnoreMissing</c>.</param>
        internal ColumnSelectingTransformer(IHostEnvironment env, string[] keepColumns, string[] dropColumns,
                                            bool keepHidden = ColumnSelectingEstimator.Defaults.KeepHidden, bool ignoreMissing = ColumnSelectingEstimator.Defaults.IgnoreMissing)
        {
            _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnSelectingTransformer));
            _host.CheckValueOrNull(keepColumns);
            _host.CheckValueOrNull(dropColumns);

            bool hasKeep = Utils.Size(keepColumns) > 0;
            bool hasDrop = Utils.Size(dropColumns) > 0;

            // Reject both lists being set.
            _host.Check(!(hasKeep && hasDrop), $"Both {nameof(keepColumns)} and {nameof(dropColumns)} are set. Exactly one can be specified.");
            // Reject neither list being set.
            _host.Check(hasKeep || hasDrop, $"Neither {nameof(keepColumns)} and {nameof(dropColumns)} is set. Exactly one must be specified.");

            if (hasKeep)
            {
                _selectedColumns = keepColumns;
            }
            else
            {
                _selectedColumns = dropColumns;
            }
            KeepColumns = hasKeep;
            KeepHidden = keepHidden;
            IgnoreMissing = ignoreMissing;
        }
 /// <summary>
 /// Loading constructor.
 /// </summary>
 /// <param name="host">host used for validation and stored as this transform's host</param>
 /// <param name="ctx">reading context (only validated here, not read)</param>
 /// <param name="input">input source stored as the secondary source</param>
 /// <param name="name">name of the transform; must be non-empty</param>
 public AbstractSimpleTransformTemplate(IHost host, ModelLoadContext ctx, IDataView input, string name)
 {
     // Report the actual parameter names. Both checks previously said "env".
     Contracts.CheckValue(host, nameof(host));
     _host = host;
     _host.CheckValue(ctx, nameof(ctx));
     _host.CheckValue(input, nameof(input));
     _host.Check(!string.IsNullOrEmpty(name));
     _sourceCtx = input;
     _sourcePipe = null;
     _lock = new object();
 }
        /// <summary>
        /// Creates a transform that either keeps or drops the listed columns. Exactly one of
        /// <paramref name="keepColumns"/> and <paramref name="dropColumns"/> must be non-empty.
        /// </summary>
        /// <param name="env">Environment used to register the transform's host.</param>
        /// <param name="keepColumns">Columns to keep, or null/empty.</param>
        /// <param name="dropColumns">Columns to drop, or null/empty.</param>
        /// <param name="keepHidden">Stored as <c>KeepHidden</c>.</param>
        /// <param name="ignoreMissing">Stored as <c>IgnoreMissing</c>.</param>
        public SelectColumnsTransform(IHostEnvironment env, string[] keepColumns, string[] dropColumns,
                                      bool keepHidden = true, bool ignoreMissing = true)
        {
            _host = Contracts.CheckRef(env, nameof(env)).Register(nameof(SelectColumnsTransform));
            _host.CheckValueOrNull(keepColumns);
            _host.CheckValueOrNull(dropColumns);

            // Arrays expose Length directly; LINQ Count() is unnecessary here.
            bool keepValid = keepColumns != null && keepColumns.Length > 0;
            bool dropValid = dropColumns != null && dropColumns.Length > 0;

            // Reject both lists being set.
            _host.Check(!(keepValid && dropValid), "Both keepColumns and dropColumns are set, only one can be specified.");
            // Reject neither list being set.
            _host.Check(keepValid || dropValid, "Neither keepColumns or dropColumns is set, one must be specified.");

            _selectedColumns = keepValid ? keepColumns : dropColumns;
            KeepColumns = keepValid;
            KeepHidden = keepHidden;
            IgnoreMissing = ignoreMissing;
        }
 /// <summary>
 /// Validates <paramref name="args"/> and captures the featurization settings it contains.
 /// </summary>
 /// <param name="host">Host used for validation and for creating the extractor factories.</param>
 /// <param name="args">User arguments. Must be non-null and specify columns; at least one of the word
 /// extractor, the char extractor or OutputTokens must be enabled.</param>
 public TransformApplierParams(IHost host, Arguments args)
 {
     Contracts.AssertValue(host);
     // Check args itself first. Otherwise a null args surfaces as a NullReferenceException below.
     host.CheckValue(args, nameof(args));
     host.CheckUserArg(args.Column != null, nameof(args.Column), "Columns must be specified");
     host.CheckUserArg(args.WordFeatureExtractor != null || args.CharFeatureExtractor != null || args.OutputTokens,
                       nameof(args.WordFeatureExtractor), "At least one feature extractor or OutputTokens must be specified.");
     host.Check(Enum.IsDefined(typeof(Language), args.Language));
     host.Check(Enum.IsDefined(typeof(CaseNormalizationMode), args.TextCase));
     WordExtractorFactory = args.WordFeatureExtractor?.CreateComponent(host, args.Dictionary);
     CharExtractorFactory = args.CharFeatureExtractor?.CreateComponent(host, args.Dictionary);
     VectorNormalizer = args.VectorNormalizer;
     Language = args.Language;
     StopWordsRemover = args.StopWordsRemover;
     TextCase = args.TextCase;
     KeepDiacritics = args.KeepDiacritics;
     KeepPunctuations = args.KeepPunctuations;
     KeepNumbers = args.KeepNumbers;
     OutputTextTokens = args.OutputTokens;
     Dictionary = args.Dictionary;
 }
Exemplo n.º 20
0
        /// <summary>
        /// Creates the estimator from an existing transformer's hashing configuration and count table.
        /// </summary>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTargetEncodingTransformer initialCounts, params InputOutputColumnPair[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(initialCounts, nameof(initialCounts));
            _host.CheckNonEmpty(columns, nameof(columns));
            _host.Check(initialCounts.VerifyColumns(columns), nameof(columns));

            // Reuse the column configuration of the existing hashing transformer.
            var hashingColumns = initialCounts.HashingTransformer.Columns.ToArray();
            _hashingEstimator = new HashingEstimator(_host, hashingColumns);

            // The count table uses each pair's output column name for both of its names.
            var countTableColumns = new InputOutputColumnPair[columns.Length];
            for (int idx = 0; idx < columns.Length; idx++)
            {
                var outputName = columns[idx].OutputColumnName;
                countTableColumns[idx] = new InputOutputColumnPair(outputName, outputName);
            }
            _countTableEstimator = new CountTableEstimator(_host, labelColumnName, initialCounts.CountTable, countTableColumns);
        }
Exemplo n.º 21
0
        /// <summary>
        /// Creates the ensemble base. The first predictor's visible input columns become the reference
        /// set (<c>_inputCols</c>). Every other predictor must expose the same number of visible columns,
        /// each present in that set.
        /// </summary>
        private SchemaBindablePipelineEnsembleBase(IHostEnvironment env, IPredictorModel[] predictors, string registrationName, string scoreColumnKind)
        {
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
            Host.CheckNonEmpty(predictors, nameof(predictors));
            Host.CheckNonWhiteSpace(scoreColumnKind, nameof(scoreColumnKind));

            PredictorModels = predictors;
            _scoreColumnKind = scoreColumnKind;

            HashSet<string> referenceCols = null;
            foreach (var model in predictors)
            {
                var schema = model.TransformModel.InputSchema;

                if (referenceCols == null)
                {
                    // First predictor: collect its visible column names.
                    referenceCols = new HashSet<string>();
                    for (int col = 0; col < schema.ColumnCount; col++)
                    {
                        if (!schema.IsHidden(col))
                        {
                            referenceCols.Add(schema.GetColumnName(col));
                        }
                    }
                    _inputCols = referenceCols.ToArray();
                    continue;
                }

                // Later predictors: every visible column must be known and the counts must match.
                int visibleCount = 0;
                for (int col = 0; col < schema.ColumnCount; col++)
                {
                    if (schema.IsHidden(col))
                    {
                        continue;
                    }
                    var colName = schema.GetColumnName(col);
                    if (!referenceCols.Contains(colName))
                    {
                        throw Host.Except("Inconsistent schemas: Some schemas do not contain the column '{0}'", colName);
                    }
                    visibleCount++;
                }
                Host.Check(visibleCount == _inputCols.Length,
                           "Inconsistent schemas: not all schemas have the same number of columns");
            }
        }
Exemplo n.º 22
0
        /// <summary>
        /// Creates the time-series imputer estimator. Not supported on CentOS 7.
        /// </summary>
        /// <param name="env">Environment used to register the estimator's host.</param>
        /// <param name="timeSeriesColumn">Time point column; must not be null.</param>
        /// <param name="grainColumns">Grain columns; at least one is required.</param>
        /// <param name="filterColumns">Filter columns; required (non-empty) when <paramref name="filterMode"/> is
        /// Include, and replaced by an empty array when null.</param>
        /// <param name="filterMode">Filter mode stored in the options.</param>
        /// <param name="imputeMode">Imputation strategy stored in the options.</param>
        /// <param name="supressTypeErrors">Stored as <c>SupressTypeErrors</c> in the options.</param>
        internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, FilterMode filterMode, ImputationStrategy imputeMode, bool supressTypeErrors)
        {
            Contracts.CheckValue(env, nameof(env));
            // _host must be assigned before any _host.Check call. The platform check previously ran
            // on an unassigned (null) host.
            _host = env.Register("TimeSeriesImputerEstimator");
            _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
            _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null.");
            _host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column.");
            if (filterMode == FilterMode.Include)
            {
                _host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified");
            }

            _options = new Options
            {
                TimeSeriesColumn = timeSeriesColumn,
                GrainColumns = grainColumns,
                FilterColumns = filterColumns ?? new string[] { },
                FilterMode = filterMode,
                ImputeMode = imputeMode,
                SupressTypeErrors = supressTypeErrors
            };
        }
Exemplo n.º 23
0
        /// <summary>
        /// Creates the categorical imputer estimator. Not supported on CentOS 7.
        /// </summary>
        /// <param name="env">Environment used to register the estimator's host.</param>
        /// <param name="options">Estimator options; must not be null. Any column whose Source is null has it
        /// set to the column's Name. This mutates the caller's options instance.</param>
        internal CategoricalImputerEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            // _host must be assigned before any _host.Check call. The platform check previously ran
            // on an unassigned (null) host.
            _host = env.Register(nameof(CategoricalImputerEstimator));
            _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
            _host.CheckValue(options, nameof(options));

            foreach (var columnPair in options.Columns)
            {
                // Default the source column to the output column name.
                columnPair.Source = columnPair.Source ?? columnPair.Name;
            }

            _options = options;
        }
Exemplo n.º 24
0
 /// <summary>
 /// Builds the PFA input type from the visible columns of <paramref name="schema"/>. Columns listed
 /// in <paramref name="toDrop"/> and columns without a PFA type are skipped. Each kept column is mapped
 /// in _nameToVarName to "input.{field}". When exactly one column is kept, its type is used directly
 /// and it maps to "input" instead of being wrapped in a record.
 /// </summary>
 private void SetInput(Schema schema, HashSet<string> toDrop)
 {
     var recordType = new JObject
     {
         ["type"] = "record",
         ["name"] = "DataInput"
     };
     var fields = new JArray();
     var usedFieldNames = new HashSet<string>();

     for (int col = 0; col < schema.Count; ++col)
     {
         var column = schema[col];
         if (column.IsHidden || toDrop.Contains(column.Name))
         {
             continue;
         }

         JToken pfaType = PfaTypeOrNullForColumn(schema, col);
         if (pfaType == null)
         {
             // No PFA representation for this column type.
             continue;
         }

         string fieldName = ModelUtils.CreateNameCore(column.Name, usedFieldNames.Contains);
         usedFieldNames.Add(fieldName);
         fields.Add(new JObject
         {
             ["name"] = fieldName,
             ["type"] = pfaType
         });
         _nameToVarName.Add(column.Name, "input." + fieldName);
     }

     _host.Assert(_nameToVarName.Count == fields.Count);
     _host.Assert(_nameToVarName.Count == fieldNames(usedFieldNames));
     recordType["fields"] = fields;
     _host.Check(fields.Count >= 1, "Schema produced no inputs for the PFA conversion.");

     if (fields.Count != 1)
     {
         Pfa.InputType = recordType;
         return;
     }

     // A single field is used as the input type directly, without a wrapping record.
     var onlyField = (JObject)fields[0];
     Pfa.InputType = onlyField["type"];
     _nameToVarName[_nameToVarName.Keys.First()] = "input";

     int fieldNames(HashSet<string> names) => names.Count;
 }
Exemplo n.º 25
0
        /// <summary>
        /// Adds a boolean initializer tensor, stored as Int32 values produced by Convert.ToInt32 (1 for true, 0 for false).
        /// </summary>
        /// <param name="values">Tensor values; must not be null.</param>
        /// <param name="dims">Optional tensor shape. When given, the product of its entries must equal the number of
        /// values; an empty shape has product 1 (a single element).</param>
        /// <param name="name">Base variable name; "bools" is used when null.</param>
        /// <param name="makeUniqueName">Forwarded to AddVariable.</param>
        /// <returns>The variable name returned by AddVariable.</returns>
        public override string AddInitializer(IEnumerable<bool> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
        {
            _host.CheckValue(values, nameof(values));

            // Materialize once: the values are both counted and converted below.
            var valueArray = values.ToArray();
            if (dims != null)
            {
                // Seeding with 1 makes an empty shape yield 1 instead of throwing InvalidOperationException.
                long elementCount = dims.Aggregate(1L, (x, y) => x * y);
                _host.Check(elementCount == valueArray.Length, "Number of elements doesn't match tensor size");
            }

            name = AddVariable(name ?? "bools", makeUniqueName);
            _initializers.Add(OnnxUtils.MakeInt32s(name, typeof(bool), valueArray.Select(v => Convert.ToInt32(v)), dims));
            return name;
        }
        /// <summary>
        /// Creates the estimator for a single input column. Output columns use
        /// <paramref name="columnPrefix"/>. Not supported on CentOS 7.
        /// </summary>
        internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, HolidayList country = HolidayList.None)
        {
            // env is already known to be non-null after this check, so it can be registered directly.
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("DateTimeTransformerEstimator");
            _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null.");
            _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");

            var opts = new Options();
            opts.Source = inputColumnName;
            opts.Prefix = columnPrefix;
            opts.Country = country;
            _options = opts;
        }
Exemplo n.º 27
0
        /// <summary>
        /// Save model to the given context
        /// </summary>
        /// <remarks>
        /// Every invariant is validated before anything is written. A failed check therefore throws before this
        /// method writes any model data, instead of leaving a half-written header behind.
        /// </remarks>
        /// <param name="ctx">Save context positioned at the model.</param>
        public void Save(ModelSaveContext ctx)
        {
            ctx.CheckAtModel();

            int leftSize = _numberOfRows * _approximationRank;
            int rightSize = _numberofColumns * _approximationRank;
            _host.Check(_numberOfRows > 0, "Number of rows must be positive");
            _host.Check(_numberofColumns > 0, "Number of columns must be positive");
            _host.Check(_approximationRank > 0, "Number of latent factors must be positive");
            _host.Check(Utils.Size(_leftFactorMatrix) == leftSize, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
            _host.Check(Utils.Size(_rightFactorMatrix) == rightSize, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");

            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of rows (m), the limit on row
            // int: number of columns (n), the limit on column
            // int: rank of factor matrices (k)
            // float[m * k]: the left factor matrix
            // float[k * n]: the right factor matrix
            ctx.Writer.Write(_numberOfRows);
            ctx.Writer.Write(_numberofColumns);
            ctx.Writer.Write(_approximationRank);
            Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix.AsSpan(0, leftSize));
            Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix.AsSpan(0, rightSize));
        }
Exemplo n.º 28
0
        /// <summary>
        /// Computes feature selection scores for every slot of each requested column.
        /// </summary>
        /// <param name="host">The host; must not be null.</param>
        /// <param name="input">The input data view.</param>
        /// <param name="labelColumnName">Name of the label column; must not be null or whitespace.</param>
        /// <param name="columns">Non-empty list of distinct column names to score.</param>
        /// <param name="numBins">Number of bins for numeric features; must exceed 1.</param>
        /// <returns>One score array per column, with one score per slot.</returns>
        public static Single[][] Train(IHost host, IDataView input, string labelColumnName, string[] columns, int numBins)
        {
            Contracts.CheckValue(host, nameof(host));
            host.CheckValue(input, nameof(input));
            host.CheckNonWhiteSpace(labelColumnName, nameof(labelColumnName));
            host.CheckValue(columns, nameof(columns));
            host.Check(columns.Length > 0, "At least one column must be specified.");
            host.Check(numBins > 1, "numBins must be greater than 1.");

            // Reject duplicates; the first repeated name is reported.
            var seenColumns = new HashSet<string>();
            foreach (var columnName in columns)
            {
                if (seenColumns.Add(columnName))
                {
                    continue;
                }
                throw host.Except("Column '{0}' specified multiple times.", columnName);
            }

            return TrainCore(host, input, labelColumnName, columns, numBins, new int[columns.Length]);
        }
        /// <summary>
        /// Creates a builder over <paramref name="table"/>. <paramref name="inputCols"/> must match the table
        /// column-for-column, including each column's slot count.
        /// </summary>
        public BagMultiCountTableBuilder(IHostEnvironment env, MultiCountTable table, DataViewSchema.Column[] inputCols, long labelCardinality)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(table, nameof(table));
            _host = env.Register(LoaderSignature);

            _host.Check(Utils.Size(inputCols) == table.ColCount, "Inconsistent number of columns");
            bool slotCountsMatch = table.SlotCount
                .Zip(inputCols, (count, col) => col.Type.GetValueCount() == count)
                .All(isMatch => isMatch);
            _host.Check(slotCountsMatch, "Inconsistent number of slots");

            _builder = table.BaseTable.ToBuilder(labelCardinality);
            _colCount = table.ColCount;

            // Keep a private copy of the per-column slot counts.
            _slotCount = new int[_colCount];
            table.SlotCount.CopyTo(_slotCount, 0);
        }
Exemplo n.º 30
0
            /// <summary>
            /// Wraps a peek delegate for a contiguous key column in a getter. The getter returns
            /// (raw - Min + 1) when the raw value lies in [Min, Min + Count - 1], and 0 otherwise.
            /// </summary>
            private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, ColumnType colType)
            {
                // Supported: key columns with a contiguous range, read as ulong, uint, byte or bool.
                Host.Check(colType.IsKey);
                Host.Check(colType.AsKey.Contiguous);
                Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) ||
                           typeof(TDst) == typeof(byte) || typeof(TDst) == typeof(bool));

                // Reinterpret the untyped delegate as a typed peek.
                var peek = peekDel as Peek<TRow, TDst>;
                Host.AssertValue(peek);

                var keyType = colType.AsKey;
                ulong lower = keyType.Min;
                // NOTE(review): if Count is 0 this wraps in unsigned arithmetic (lower - 1); confirm Count > 0 here.
                ulong upper = lower + (ulong)keyType.Count - 1;

                // Buffer shared across calls and passed by ref to the peek delegate.
                TDst rawKeyValue = default;
                ValueGetter<TDst> getter = (ref TDst dst) =>
                {
                    peek(GetCurrentRowObject(), Position, ref rawKeyValue);
                    ulong raw = (ulong)Convert.ChangeType(rawKeyValue, typeof(ulong));
                    ulong mapped = (raw >= lower && raw <= upper) ? raw - lower + 1 : 0;
                    dst = (TDst)Convert.ChangeType(mapped, typeof(TDst));
                };

                return getter;
            }