Example No. 1
        /// <summary>
        /// Get the getter for the group column, or null if there is no group column.
        /// </summary>
        public static ValueGetter <ulong> GetOptGroupGetter(this DataViewRow row, RoleMappedSchema schema)
        {
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(schema.Schema == row.Schema, "schemas don't match!");

            var col = schema.Group;

            if (!col.HasValue)
            {
                return(null);
            }
            return(RowCursorUtils.GetGetterAs <ulong>(NumberDataViewType.UInt64, row, col.Value.Index));
        }
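A minimal consumption sketch for the getter above (hedged; `cursor` and `schema` are assumed names for a compatible DataViewRowCursor and RoleMappedSchema — the getter may be null when no group column is defined):
        // Sketch only: invoke the (possibly null) group getter while advancing the cursor.
        var groupGetter = cursor.GetOptGroupGetter(schema);
        ulong groupId = 0;
        while (cursor.MoveNext())
        {
            groupGetter?.Invoke(ref groupId);
            // ... use groupId for the current row ...
        }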
Example No. 2
        /// <summary>
        /// Get the getter for the weight column, or null if there is no weight column.
        /// </summary>
        public static ValueGetter <float> GetOptWeightFloatGetter(this Row row, RoleMappedSchema schema)
        {
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(schema.Schema == row.Schema, "schemas don't match!");

            var col = schema.Weight;

            if (!col.HasValue)
            {
                return(null);
            }
            return(RowCursorUtils.GetGetterAs <float>(NumberType.Float, row, col.Value.Index));
        }
Example No. 3
        protected override Delegate[] CreatePredictionGetters(Booster xgboostModel, IRow input, Func <int, bool> predicate)
        {
            var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate);

            xgboostModel.LazyInit();
            var getters = new Delegate[1];

            if (active[0])
            {
                var             featureGetter   = RowCursorUtils.GetVecGetterAs <Float>(PrimitiveType.FromKind(DataKind.R4), input, InputSchema.Feature.Index);
                VBuffer <Float> features        = new VBuffer <Float>();
                var             postProcessor   = Parent.GetOutputPostProcessor();
                int             expectedLength  = input.Schema.GetColumnType(InputSchema.Feature.Index).VectorSize;
                var             xgboostBuffer   = Booster.CreateInternalBuffer();
                int             nbMappedClasses = _classMapping == null ? 0 : _numberOfClasses;

                if (nbMappedClasses == 0)
                {
                    ValueGetter <VBuffer <Float> > localGetter = (ref VBuffer <Float> prediction) =>
                    {
                        featureGetter(ref features);
                        Contracts.Assert(features.Length == expectedLength);
                        xgboostModel.Predict(ref features, ref prediction, ref xgboostBuffer);
                        postProcessor(ref prediction);
                    };
                    getters[0] = localGetter;
                }
                else
                {
                    ValueGetter <VBuffer <Float> > localGetter = (ref VBuffer <Float> prediction) =>
                    {
                        featureGetter(ref features);
                        Contracts.Assert(features.Length == expectedLength);
                        xgboostModel.Predict(ref features, ref prediction, ref xgboostBuffer);
                        Contracts.Assert(prediction.IsDense);
                        postProcessor(ref prediction);
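                        // Re-map the dense scores onto a sparse buffer of logical length nbMappedClasses, with _classMapping supplying the target indices.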
                        var indices = prediction.Indices;
                        if (indices == null || indices.Length < _classMapping.Length)
                        {
                            indices = new int[_classMapping.Length];
                        }
                        Array.Copy(_classMapping, indices, _classMapping.Length);
                        prediction = new VBuffer <float>(nbMappedClasses, _classMapping.Length, prediction.Values, indices);
                    };
                    getters[0] = localGetter;
                }
            }
            return(getters);
        }
Example No. 4
        /// <summary>
        /// Get the getter for the group column, or null if there is no group column.
        /// </summary>
        public static ValueGetter <ulong> GetOptGroupGetter(this IRow row, RoleMappedSchema schema)
        {
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(schema.Schema == row.Schema, "schemas don't match!");
            Contracts.CheckValueOrNull(schema.Group);

            var col = schema.Group;

            if (col == null)
            {
                return(null);
            }
            return(RowCursorUtils.GetGetterAs <ulong>(NumberType.U8, row, col.Index));
        }
Example No. 5
        private static double CalculateAvgLoss(IChannel ch, RoleMappedData data, bool norm, float[] linearWeights, AlignedArray latentWeightsAligned,
                                               int latentDimAligned, AlignedArray latentSum, int[] featureFieldBuffer, int[] featureIndexBuffer, float[] featureValueBuffer, VBuffer <float> buffer, ref long badExampleCount)
        {
            var featureColumns    = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
            Func <int, bool> pred = c => featureColumns.Select(ci => ci.Index).Contains(c) || c == data.Schema.Label.Value.Index || c == data.Schema.Weight?.Index;
            var    getters        = new ValueGetter <VBuffer <float> > [featureColumns.Count];
            float  label          = 0;
            float  weight         = 1;
            double loss           = 0;
            float  modelResponse  = 0;
            long   exampleCount   = 0;

            badExampleCount = 0;
            int count = 0;

            using (var cursor = data.Data.GetRowCursor(pred))
            {
                var labelGetter  = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
                var weightGetter = data.Schema.Weight?.Index is int weightIdx ? cursor.GetGetter <float>(weightIdx) : null;

                for (int f = 0; f < featureColumns.Count; f++)
                {
                    getters[f] = cursor.GetGetter <VBuffer <float> >(featureColumns[f].Index);
                }
                while (cursor.MoveNext())
                {
                    labelGetter(ref label);
                    weightGetter?.Invoke(ref weight);
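                    // If either the label or the weight is NaN or infinite, this sum is non-finite and the example is skipped below.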
                    float annihilation = label - label + weight - weight;
                    if (!FloatUtils.IsFinite(annihilation))
                    {
                        badExampleCount++;
                        continue;
                    }
                    if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(getters, buffer, norm, ref count,
                                                                                      featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
                    {
                        badExampleCount++;
                        continue;
                    }
                    FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(featureColumns.Count, latentDimAligned, count,
                                                                                           featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
                    loss += weight * CalculateLoss(label, modelResponse);
                    exampleCount++;
                }
            }
            return(loss / exampleCount);
        }
Example No. 6
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Host.AssertValueOrNull(ch);
            Host.AssertValue(input);
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            disposer = null;

            var typeSrc = Infos[iinfo].TypeSrc;
            var typeDst = _exes[iinfo].TypeDst;

            if (!typeDst.IsVector())
            {
                return(GetGetterAs(typeDst, input, Infos[iinfo].Source));
            }
            return(RowCursorUtils.GetVecGetterAs(typeDst.AsVector().ItemType(), input, Infos[iinfo].Source));
        }
Example No. 7
            public DataViewRowCursor GetRowCursor(IEnumerable <DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);

                _host.AssertValueOrNull(rand);

                // Build out the active state for the input
                var inputPred = GetDependencies(predicate);
                var inputCols = Source.Schema.Where(x => inputPred(x.Index));

                var inputRowCursor = Source.GetRowCursor(inputCols, rand);

                // Build the active state for the output
                var active = Utils.BuildArray(_mapper.OutputSchema.Count, columnsNeeded);

                return(new Cursor(_host, _mapper, inputRowCursor, active));
            }
Example No. 8
                public Impl(Row input, int pyColIndex, int idvColIndex, ColumnType type, ValuePoker <TSrc> poker)
                    : base(input, pyColIndex)
                {
                    Contracts.AssertValue(input);
                    Contracts.Assert(0 <= idvColIndex && idvColIndex < input.Schema.Count);

                    if (type.IsVector)
                    {
                        _getVec = RowCursorUtils.GetVecGetterAs <TSrc>((PrimitiveType)type.ItemType, input, idvColIndex);
                    }
                    else
                    {
                        _get = RowCursorUtils.GetGetterAs <TSrc>(type, input, idvColIndex);
                    }

                    _poker = poker;
                }
Example No. 9
        private ValueGetter <VBuffer <Float> > GetTopic(IRow input, int iinfo)
        {
            var  getSrc        = RowCursorUtils.GetVecGetterAs <Double>(NumberType.R8, input, Infos[iinfo].Source);
            var  src           = default(VBuffer <Double>);
            var  lda           = _ldas[iinfo];
            int  numBurninIter = lda.InfoEx.NumBurninIter;
            bool reset         = lda.InfoEx.ResetRandomGenerator;

            return
                ((ref VBuffer <Float> dst) =>
            {
                // REVIEW: This will work, but there are opportunities for caching
                // based on input.Counter that are probably worthwhile given how long inference takes.
                getSrc(ref src);
                lda.Output(ref src, ref dst, numBurninIter, reset);
            });
        }
Example No. 10
        public override RowCursor[] GetRowCursorSet(IEnumerable <Schema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.CheckValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            var inputPred = _bindings.GetDependencies(predicate);

            var inputCols = Source.Schema.Where(x => inputPred(x.Index));

            RowCursor[] cursors = Source.GetRowCursorSet(inputCols, n, rand);
            bool        active  = predicate(_bindings.MapIinfoToCol(0));

            for (int c = 0; c < cursors.Length; ++c)
            {
                cursors[c] = new Cursor(Host, _bindings, cursors[c], active);
            }
            return(cursors);
        }
Example No. 11
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {

            Host.CheckValueOrNull(rand);

            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
            Func<int, bool> inputPred = GetActive(predicate, out bool[] active);

            var inputCols = Source.Schema.Where(x => inputPred(x.Index));
            var inputs = Source.GetRowCursorSet(inputCols, n, rand);
            Host.AssertNonEmpty(inputs);

            // No need to split if this is given 1 input cursor.
            var cursors = new DataViewRowCursor[inputs.Length];
            for (int i = 0; i < inputs.Length; i++)
                cursors[i] = CreateCursorCore(inputs[i], active);
            return cursors;
        }
Example No. 12
                public Impl(DataViewRow input, int pyColIndex, int idvColIndex, DataViewType type, ValuePoker <TSrc> poker)
                    : base(input, pyColIndex)
                {
                    Contracts.AssertValue(input);
                    Contracts.Assert(0 <= idvColIndex && idvColIndex < input.Schema.Count);

                    if (type is VectorDataViewType)
                    {
                        _getVec = RowCursorUtils.GetVecGetterAs <TSrc>((PrimitiveDataViewType)type.GetItemType(), input, idvColIndex);
                    }
                    else
                    {
                        _get = RowCursorUtils.GetGetterAs <TSrc>(type, input, idvColIndex);
                    }

                    _poker       = poker;
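                    // A value count of 0 indicates a variable-length vector column.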
                    _isVarLength = (type.GetValueCount() == 0);
                }
Example No. 13
        internal MultiClassClassifierMetrics(IExceptionContext ectx, Row overallResult, int topK)
        {
            double FetchDouble(string name) => RowCursorUtils.Fetch <double>(ectx, overallResult, name);

            AccuracyMicro    = FetchDouble(MultiClassClassifierEvaluator.AccuracyMicro);
            AccuracyMacro    = FetchDouble(MultiClassClassifierEvaluator.AccuracyMacro);
            LogLoss          = FetchDouble(MultiClassClassifierEvaluator.LogLoss);
            LogLossReduction = FetchDouble(MultiClassClassifierEvaluator.LogLossReduction);
            TopK             = topK;
            if (topK > 0)
            {
                TopKAccuracy = FetchDouble(MultiClassClassifierEvaluator.TopKAccuracy);
            }

            var perClassLogLoss = RowCursorUtils.Fetch <VBuffer <double> >(ectx, overallResult, MultiClassClassifierEvaluator.PerClassLogLoss);

            PerClassLogLoss = new double[perClassLogLoss.Length];
            perClassLogLoss.CopyTo(PerClassLogLoss);
        }
Example No. 14
        public override DataViewRowCursor[] GetRowCursorSet(IEnumerable <DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
        {
            Host.AssertValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);

            Func <int, bool> inputPred = TypedSrc.GetDependencies(predicate);

            var inputCols = Input.Schema.Where(x => inputPred(x.Index));
            var inputs    = Input.GetRowCursorSet(inputCols, n, rand);

            Host.AssertNonEmpty(inputs);
            var active = Utils.BuildArray(Input.Schema.Count, inputCols);

            // No need to split if this is given 1 input cursor.
            var cursors = new DataViewRowCursor[inputs.Length];

            for (int i = 0; i < inputs.Length; i++)
            {
                cursors[i] = new Cursor(this, inputs[i], active);
            }
            return(cursors);
        }
Example No. 15
        private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator decorator = null)
        {
            ch.Assert(_bindings.Infos[iinfo].SrcTypes.All(t => t.IsVector && t.ItemType.IsKey));

            var info     = _bindings.Infos[iinfo];
            int srcCount = info.SrcIndices.Length;

            ValueGetter <VBuffer <uint> >[] getSrc = new ValueGetter <VBuffer <uint> > [srcCount];
            for (int isrc = 0; isrc < srcCount; isrc++)
            {
                getSrc[isrc] = RowCursorUtils.GetVecGetterAs <uint>(NumberType.U4, input, info.SrcIndices[isrc]);
            }
            var src           = default(VBuffer <uint>);
            var ngramIdFinder = GetNgramIdFinder(iinfo);

            if (decorator != null)
            {
                ngramIdFinder = decorator(iinfo, ngramIdFinder);
            }
            var bldr = new NgramBufferBuilder(_exes[iinfo].NgramLength, _exes[iinfo].SkipLength,
                                              _bindings.Types[iinfo].ValueCount, ngramIdFinder);
            var keyCounts = _bindings.Infos[iinfo].SrcTypes.Select(
                t => t.ItemType.KeyCount > 0 ? (uint)t.ItemType.KeyCount : uint.MaxValue).ToArray();

            // REVIEW: Special casing the srcCount==1 case could potentially improve perf.
            ValueGetter <VBuffer <Float> > del =
                (ref VBuffer <Float> dst) =>
            {
                bldr.Reset();
                for (int i = 0; i < srcCount; i++)
                {
                    getSrc[i](ref src);
                    bldr.AddNgrams(ref src, i, keyCounts[i]);
                }
                bldr.GetResult(ref dst);
            };

            return(del);
        }
Example No. 16
                public CsrFiller(DataViewRow input,
                                 int idvColIndex,
                                 DataViewType type,
                                 DataAppender <TSrc> dataAppender,
                                 CsrData csrData)
                    : base()
                {
                    Contracts.AssertValue(input);
                    Contracts.Assert(0 <= idvColIndex && idvColIndex < input.Schema.Count);

                    if (type is VectorDataViewType)
                    {
                        _getVec = RowCursorUtils.GetVecGetterAs <TSrc>((PrimitiveDataViewType)type.GetItemType(), input, idvColIndex);
                    }
                    else
                    {
                        _get = RowCursorUtils.GetGetterAs <TSrc>(type, input, idvColIndex);
                    }

                    _csrData      = csrData;
                    _dataAppender = dataAppender;
                }
Example No. 17
        public void AssertStaticKeys()
        {
            var env     = new ConsoleEnvironment(0, verbose: true);
            var counted = new MetaCounted();

            // We'll test a few things here. First, the case where the key-value metadata is text.
            var  metaValues1 = new VBuffer <ReadOnlyMemory <char> >(3, new[] { "a".AsMemory(), "b".AsMemory(), "c".AsMemory() });
            var  meta1       = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(TextType.Instance, 3), ref metaValues1);
            uint value1      = 2;
            var  col1        = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1));

            // Next the case where those values are ints.
            var metaValues2 = new VBuffer <int>(3, new int[] { 1, 2, 3, 4 });
            var meta2       = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2);
            var value2      = new VBuffer <byte>(2, 0, null, null);
            var col2        = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2));

            // Then the case where a value of that kind exists, but is not of the right kind, in which case it should not be identified as containing that metadata.
            var metaValues3 = (float)2;
            var meta3       = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, NumberType.R4, ref metaValues3);
            var value3      = (ushort)1;
            var col3        = RowColumnUtils.GetColumn("and", new KeyType(DataKind.U2, 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3));

            // Then a final case where metadata of that kind is actually simply altogether absent.
            var value4 = new VBuffer <uint>(5, 0, null, null);
            var col4   = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(DataKind.U4, 0, 2)), ref value4);

            // Finally compose a trivial data view out of all this.
            var row  = RowColumnUtils.GetRow(counted, col1, col2, col3, col4);
            var view = RowCursorUtils.RowAsDataView(env, row);

            // Whew! I'm glad that's over with. Let us start running the test in earnest.
            // First let's do a direct match of the types to ensure that works.
            view.AssertStatic(env, c => (
                                  stay: c.KeyU4.TextValues.Scalar,
                                  awhile: c.KeyU1.I4Values.Vector,
                                  and: c.KeyU2.NoValue.Scalar,
                                  listen: c.KeyU4.NoValue.VarVector));

            // Next let's match against the superclasses (where no value types are
            // asserted), to ensure that the less specific case still passes.
            view.AssertStatic(env, c => (
                                  stay: c.KeyU4.NoValue.Scalar,
                                  awhile: c.KeyU1.NoValue.Vector,
                                  and: c.KeyU2.NoValue.Scalar,
                                  listen: c.KeyU4.NoValue.VarVector));

            // Here we assert a subset.
            view.AssertStatic(env, c => (
                                  stay: c.KeyU4.TextValues.Scalar,
                                  awhile: c.KeyU1.I4Values.Vector));

            // OK. Now we've confirmed the basic stuff works, let's check other scenarios.
            // Due to the fact that we cannot yet assert only a *single* column, these always appear
            // in at least pairs.

            // First try to get the right type of exception to test against.
            Type e = null;

            try
            {
                view.AssertStatic(env, c => (
                                      stay: c.KeyU4.TextValues.Scalar,
                                      awhile: c.KeyU2.I4Values.Vector));
            }
            catch (Exception eCaught)
            {
                e = eCaught.GetType();
            }
            Assert.NotNull(e);

            // What if the key representation type is wrong?
            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU4.TextValues.Scalar,
                                                awhile: c.KeyU2.I4Values.Vector)));

            // What if the key value type is wrong?
            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU4.TextValues.Scalar,
                                                awhile: c.KeyU1.I2Values.Vector)));

            // Same two tests, but for scalar?
            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU2.TextValues.Scalar,
                                                awhile: c.KeyU1.I2Values.Vector)));

            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU4.BoolValues.Scalar,
                                                awhile: c.KeyU1.I2Values.Vector)));

            // How about if we misidentify the vectorness?
            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU4.TextValues.Vector,
                                                awhile: c.KeyU1.I2Values.Vector)));

            // How about the names?
            Assert.Throws(e, () =>
                          view.AssertStatic(env, c => (
                                                stay: c.KeyU4.TextValues.Scalar,
                                                alot: c.KeyU1.I4Values.Vector)));
        }
Example No. 18
        private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
                                                                        RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null)
        {
            _host.AssertValue(ch);
            _host.AssertValue(pch);

            data.CheckBinaryLabel();
            var featureColumns    = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
            int fieldCount        = featureColumns.Count;
            int totalFeatureCount = 0;

            int[] fieldColumnIndexes = new int[fieldCount];
            for (int f = 0; f < fieldCount; f++)
            {
                var col = featureColumns[f];
                _host.Assert(!col.IsHidden);
                if (!(col.Type is VectorDataViewType vectorType) ||
                    !vectorType.IsKnownSize ||
                    vectorType.ItemType != NumberDataViewType.Single)
                {
                    throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of Single, but has type: {1}.", col.Name, col.Type);
                }
                _host.Assert(vectorType.Size > 0);
                fieldColumnIndexes[f] = col.Index;
                totalFeatureCount    += vectorType.Size;
            }
            ch.Check(checked (totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large");
            if (predictor != null)
            {
                ch.Check(predictor.FeatureCount == totalFeatureCount, "Input model's feature count mismatches training feature count");
                ch.Check(predictor.LatentDimension == _latentDim, "Input model's latent dimension mismatches trainer's");
            }
            if (validData != null)
            {
                validData.CheckBinaryLabel();
                var validFeatureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
                _host.Assert(fieldCount == validFeatureColumns.Count);
                for (int f = 0; f < fieldCount; f++)
                {
                    var featCol      = featureColumns[f];
                    var validFeatCol = validFeatureColumns[f];
                    _host.Assert(featCol.Name == validFeatCol.Name);
                    _host.Assert(featCol.Type == validFeatCol.Type);
                }
            }
            bool shuffle = _shuffle;

            if (shuffle && !data.Data.CanShuffle)
            {
                ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
                shuffle = false;
            }
            var rng                = shuffle ? _host.Rand : null;
            var featureGetters     = new ValueGetter <VBuffer <float> > [fieldCount];
            var featureBuffer      = new VBuffer <float>();
            var featureValueBuffer = new float[totalFeatureCount];
            var featureIndexBuffer = new int[totalFeatureCount];
            var featureFieldBuffer = new int[totalFeatureCount];
            var latentSum          = new AlignedArray(fieldCount * fieldCount * _latentDimAligned, 16);
            var metricNames        = new List <string>()
            {
                "Training-loss"
            };

            if (validData != null)
            {
                metricNames.Add("Validation-loss");
            }
            int    iter                 = 0;
            long   exampleCount         = 0;
            long   badExampleCount      = 0;
            long   validBadExampleCount = 0;
            double loss                 = 0;
            double validLoss            = 0;

            pch.SetHeader(new ProgressHeader(metricNames.ToArray(), new string[] { "iterations", "examples" }), entry =>
            {
                entry.SetProgress(0, iter, _numIterations);
                entry.SetProgress(1, exampleCount);
            });

            var columns = data.Schema.Schema.Where(x => fieldColumnIndexes.Contains(x.Index)).ToList();

            columns.Add(data.Schema.Label.Value);
            if (data.Schema.Weight != null)
            {
                columns.Add(data.Schema.Weight.Value);
            }

            InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
                                    out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);

            // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
            while (iter++ < _numIterations)
            {
                using (var cursor = data.Data.GetRowCursor(columns, rng))
                {
                    var labelGetter  = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
                    var weightGetter = data.Schema.Weight?.Index is int weightIdx ? RowCursorUtils.GetGetterAs <float>(NumberDataViewType.Single, cursor, weightIdx) : null;

                    for (int i = 0; i < fieldCount; i++)
                    {
                        featureGetters[i] = cursor.GetGetter <VBuffer <float> >(cursor.Schema[fieldColumnIndexes[i]]);
                    }
                    loss            = 0;
                    exampleCount    = 0;
                    badExampleCount = 0;
                    while (cursor.MoveNext())
                    {
                        float label         = 0;
                        float weight        = 1;
                        int   count         = 0;
                        float modelResponse = 0;
                        labelGetter(ref label);
                        weightGetter?.Invoke(ref weight);
                        float annihilation = label - label + weight - weight;
                        if (!FloatUtils.IsFinite(annihilation))
                        {
                            badExampleCount++;
                            continue;
                        }
                        if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(featureGetters, featureBuffer, _norm, ref count,
                                                                                          featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
                        {
                            badExampleCount++;
                            continue;
                        }

                        // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
                                                                                               featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
                        var slope = CalculateLossSlope(label, modelResponse);

                        // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
                                                                                           featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);
                        loss += weight * CalculateLoss(label, modelResponse);
                        exampleCount++;
                    }
                    loss /= exampleCount;
                }

                if (_verbose)
                {
                    if (validData == null)
                    {
                        pch.Checkpoint(loss, iter, exampleCount);
                    }
                    else
                    {
                        validLoss = CalculateAvgLoss(ch, validData, _norm, linearWeights, latentWeightsAligned, _latentDimAligned, latentSum,
                                                     featureFieldBuffer, featureIndexBuffer, featureValueBuffer, featureBuffer, ref validBadExampleCount);
                        pch.Checkpoint(loss, validLoss, iter, exampleCount);
                    }
                }
            }
            if (badExampleCount != 0)
            {
                ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
            }
            if (validBadExampleCount != 0)
            {
                ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
            }

            return(new FieldAwareFactorizationMachineModelParameters(_host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned));
        }
Example No. 19
        private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);
            ch.AssertValueOrNull(validData);

            ColumnInfo matrixColumnIndexColInfo;
            ColumnInfo matrixRowIndexColInfo;
            ColumnInfo validMatrixColumnIndexColInfo = null;
            ColumnInfo validMatrixRowIndexColInfo    = null;

            ch.CheckValue(data.Schema.Label, nameof(data), "Input data did not have a unique label");
            RecommenderUtils.CheckAndGetMatrixIndexColumns(data, out matrixColumnIndexColInfo, out matrixRowIndexColInfo, isDecode: false);
            if (data.Schema.Label.Type != NumberType.R4 && data.Schema.Label.Type != NumberType.R8)
            {
                throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", data.Schema.Label.Name, data.Schema.Label.Type);
            }
            MatrixFactorizationPredictor predictor;

            if (validData != null)
            {
                ch.CheckValue(validData, nameof(validData));
                ch.CheckValue(validData.Schema.Label, nameof(validData), "Input validation data did not have a unique label");
                RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out validMatrixColumnIndexColInfo, out validMatrixRowIndexColInfo, isDecode: false);
                if (validData.Schema.Label.Type != NumberType.R4 && validData.Schema.Label.Type != NumberType.R8)
                {
                    throw ch.Except("Column '{0}' for validation label should be floating point, but is instead {1}", data.Schema.Label.Name, data.Schema.Label.Type);
                }

                if (!matrixColumnIndexColInfo.Type.Equals(validMatrixColumnIndexColInfo.Type))
                {
                    throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-column types differed, {0} vs. {1}",
                                         matrixColumnIndexColInfo.Type, validMatrixColumnIndexColInfo.Type);
                }
                if (!matrixRowIndexColInfo.Type.Equals(validMatrixRowIndexColInfo.Type))
                {
                    throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-row types differed, {0} vs. {1}",
                                         matrixRowIndexColInfo.Type, validMatrixRowIndexColInfo.Type);
                }
            }

            int colCount = matrixColumnIndexColInfo.Type.KeyCount;
            int rowCount = matrixRowIndexColInfo.Type.KeyCount;

            ch.Assert(rowCount > 0);
            ch.Assert(colCount > 0);

            // Checks for equality on the validation set ensure it is correct here.
            using (var cursor = data.Data.GetRowCursor(c => c == matrixColumnIndexColInfo.Index || c == matrixRowIndexColInfo.Index || c == data.Schema.Label.Index))
            {
                // LibMF works only over single precision floats, but we want to be able to consume either.
                var labGetter = RowCursorUtils.GetGetterAs <float>(NumberType.R4, cursor, data.Schema.Label.Index);
                var matrixColumnIndexGetter = RowCursorUtils.GetGetterAs <uint>(NumberType.U4, cursor, matrixColumnIndexColInfo.Index);
                var matrixRowIndexGetter    = RowCursorUtils.GetGetterAs <uint>(NumberType.U4, cursor, matrixRowIndexColInfo.Index);

                if (validData == null)
                {
                    // Have the trainer do its work.
                    using (var buffer = PrepareBuffer())
                    {
                        buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
                        predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
                    }
                }
                else
                {
                    using (var validCursor = validData.Data.GetRowCursor(
                               c => c == validMatrixColumnIndexColInfo.Index || c == validMatrixRowIndexColInfo.Index || c == validData.Schema.Label.Index))
                    {
                        ValueGetter <float> validLabelGetter = RowCursorUtils.GetGetterAs <float>(NumberType.R4, validCursor, validData.Schema.Label.Index);
                        var validMatrixColumnIndexGetter     = RowCursorUtils.GetGetterAs <uint>(NumberType.U4, validCursor, validMatrixColumnIndexColInfo.Index);
                        var validMatrixRowIndexGetter        = RowCursorUtils.GetGetterAs <uint>(NumberType.U4, validCursor, validMatrixRowIndexColInfo.Index);

                        // Have the trainer do its work.
                        using (var buffer = PrepareBuffer())
                        {
                            buffer.TrainWithValidation(ch, rowCount, colCount,
                                                       cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
                                                       validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter);
                            predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
                        }
                    }
                }
            }
            return(predictor);
        }
Example No. 20
        private ValueGetter <long> GetLabelGetter(DataViewRow row, DataViewSchema.Column col)
        {
            // The label column type is checked as part of args validation.
            var type = col.Type;

            _host.Assert(type is KeyDataViewType || type is NumberDataViewType || type is BooleanDataViewType);

            if (type is BooleanDataViewType)
            {
                bool src    = default;
                var  getSrc = row.GetGetter <bool>(col);
                return
                    ((ref long dst) =>
                {
                    getSrc(ref src);
                    if (src)
                    {
                        dst = 1;
                    }
                    else
                    {
                        dst = 0;
                    }
                });
            }
            if (type is KeyDataViewType)
            {
                _host.Assert(type.GetKeyCount() > 0);

                int   size   = type.GetKeyCountAsInt32();
                ulong src    = 0;
                var   getSrc = RowCursorUtils.GetGetterAs <ulong>(NumberDataViewType.UInt64, row, col.Index);
                return
                    ((ref long dst) =>
                {
                    getSrc(ref src);
                    // The value should fall between 0 and size inclusive, where 0 is considered
                    // missing/invalid (this is the contract of the KeyType). However, we still handle the
                    // cases of too large values correctly (by treating them as invalid).
                    if (src <= (ulong)size)
                    {
                        dst = (long)src - 1;
                    }
                    else
                    {
                        dst = -1;
                    }
                });
            }
            else
            {
                double src    = 0;
                var    getSrc = RowCursorUtils.GetGetterAs <double>(NumberDataViewType.Double, row, col.Index);
                return
                    ((ref long dst) =>
                {
                    getSrc(ref src);
                    // NaN maps to -1.
                    if (double.IsNaN(src))
                    {
                        dst = -1;
                    }
                    else
                    {
                        dst = (long)src;
                    }
                });
            }
        }
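A hedged usage sketch for the label getter above (`cursor` and `labelColumn` are assumed names for a DataViewRowCursor and the label's DataViewSchema.Column):
        // Sketch only: convert each row's label to a long while cursoring; -1 marks a missing key or NaN.
        var labelGetter = GetLabelGetter(cursor, labelColumn);
        long label = 0;
        while (cursor.MoveNext())
        {
            labelGetter(ref label);
            // label is 0/1 for booleans, (key value - 1) for keys, or the truncated numeric value otherwise.
        }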
Example No. 21
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Host.AssertValueOrNull(ch);
            Host.AssertValue(input);
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            Host.Assert(Infos[iinfo].TypeSrc.IsVector);
            Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);

            disposer = null;

            var getSrc = RowCursorUtils.GetVecGetterAs <uint>(NumberType.U4, input, Infos[iinfo].Source);
            var src    = default(VBuffer <uint>);
            var bldr   = new NgramBufferBuilder(_exes[iinfo].NgramLength, _exes[iinfo].SkipLength,
                                                _ngramMaps[iinfo].Count, GetNgramIdFinder(iinfo));
            var keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;

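            // A KeyCount of 0 means the key type has unknown cardinality; treat it as unbounded.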
            if (keyCount == 0)
            {
                keyCount = uint.MaxValue;
            }

            ValueGetter <VBuffer <Float> > del;

            switch (_exes[iinfo].Weighting)
            {
            case WeightingCriteria.TfIdf:
                Host.AssertValue(_invDocFreqs[iinfo]);
                del =
                    (ref VBuffer <Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(in src, 0, keyCount);
                        bldr.GetResult(ref dst);
                        VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = (Float)(v * _invDocFreqs[iinfo][i]));
                    }
                    else
                    {
                        dst = new VBuffer <Float>(0, dst.Values, dst.Indices);
                    }
                };
                break;

            case WeightingCriteria.Idf:
                Host.AssertValue(_invDocFreqs[iinfo]);
                del =
                    (ref VBuffer <Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(in src, 0, keyCount);
                        bldr.GetResult(ref dst);
                        VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = v >= 1 ? (Float)_invDocFreqs[iinfo][i] : 0);
                    }
                    else
                    {
                        dst = new VBuffer <Float>(0, dst.Values, dst.Indices);
                    }
                };
                break;

            case WeightingCriteria.Tf:
                del =
                    (ref VBuffer <Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(in src, 0, keyCount);
                        bldr.GetResult(ref dst);
                    }
                    else
                    {
                        dst = new VBuffer <Float>(0, dst.Values, dst.Indices);
                    }
                };
                break;

            default:
                throw Host.Except("Unsupported weighting criteria");
            }

            return(del);
        }
Example No. 22
        private SequencePool[] Train(Arguments args, IDataView trainingData, out double[][] invDocFreqs)
        {
            // Contains the maximum number of grams to store in the dictionary, for each level of ngrams,
            // from 1 (in position 0) up to ngramLength (in position ngramLength-1)
            var lims = new int[Infos.Length][];

            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                var all         = args.Column[iinfo].AllLengths ?? args.AllLengths;
                var ngramLength = _exes[iinfo].NgramLength;
                var maxNumTerms = Utils.Size(args.Column[iinfo].MaxNumTerms) > 0 ? args.Column[iinfo].MaxNumTerms : args.MaxNumTerms;
                if (!all)
                {
                    Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
                                      Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(args.MaxNumTerms));
                    lims[iinfo] = new int[ngramLength];
                    lims[iinfo][ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[0];
                }
                else
                {
                    Host.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(args.MaxNumTerms));
                    Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(args.MaxNumTerms));
                    var extend = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[maxNumTerms.Length - 1];
                    lims[iinfo] = Utils.BuildArray(ngramLength,
                                                   i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
                }
            }

            var helpers = new NgramBufferBuilder[Infos.Length];
            var getters = new ValueGetter <VBuffer <uint> > [Infos.Length];
            var src     = new VBuffer <uint> [Infos.Length];

            // Keep track of how many grams are in the pool for each value of n. Position
            // i in _counts counts how many (i+1)-grams are in the pool for column iinfo.
            var counts    = new int[Infos.Length][];
            var ngramMaps = new SequencePool[Infos.Length];

            bool[] activeInput = new bool[trainingData.Schema.ColumnCount];
            foreach (var info in Infos)
            {
                activeInput[info.Source] = true;
            }
            using (var cursor = trainingData.GetRowCursor(col => activeInput[col]))
                using (var pch = Host.StartProgressChannel("Building n-gram dictionary"))
                {
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        Host.Assert(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsKey);
                        var ngramLength = _exes[iinfo].NgramLength;
                        var skipLength  = _exes[iinfo].SkipLength;

                        getters[iinfo]   = RowCursorUtils.GetVecGetterAs <uint>(NumberType.U4, cursor, Infos[iinfo].Source);
                        src[iinfo]       = default(VBuffer <uint>);
                        counts[iinfo]    = new int[ngramLength];
                        ngramMaps[iinfo] = new SequencePool();

                        // Note: GetNgramIdFinderAdd will control how many ngrams of a specific length will
                        // be added (using lims[iinfo]), therefore we set slotLim to the maximum
                        helpers[iinfo] = new NgramBufferBuilder(ngramLength, skipLength, Utils.ArrayMaxSize,
                                                                GetNgramIdFinderAdd(counts[iinfo], lims[iinfo], ngramMaps[iinfo], _exes[iinfo].RequireIdf(), Host));
                    }

                    int    cInfoFull = 0;
                    bool[] infoFull  = new bool[Infos.Length];

                    invDocFreqs = new double[Infos.Length][];

                    long   totalDocs = 0;
                    Double rowCount  = trainingData.GetRowCount() ?? Double.NaN;
                    var    buffers   = new VBuffer <float> [Infos.Length];
                    pch.SetHeader(new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }),
                                  e => e.SetProgress(0, totalDocs, rowCount));
                    while (cInfoFull < Infos.Length && cursor.MoveNext())
                    {
                        totalDocs++;
                        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                        {
                            getters[iinfo](ref src[iinfo]);
                            var keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
                            if (keyCount == 0)
                            {
                                keyCount = uint.MaxValue;
                            }
                            if (!infoFull[iinfo])
                            {
                                if (_exes[iinfo].RequireIdf())
                                {
                                    helpers[iinfo].Reset();
                                }

                                helpers[iinfo].AddNgrams(in src[iinfo], 0, keyCount);
                                if (_exes[iinfo].RequireIdf())
                                {
                                    int totalNgrams = counts[iinfo].Sum();
                                    Utils.EnsureSize(ref invDocFreqs[iinfo], totalNgrams);
                                    helpers[iinfo].GetResult(ref buffers[iinfo]);
                                    foreach (var pair in buffers[iinfo].Items())
                                    {
                                        if (pair.Value >= 1)
                                        {
                                            invDocFreqs[iinfo][pair.Key] += 1;
                                        }
                                    }
                                }
                            }
                            AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);
                        }
                    }

                    pch.Checkpoint(counts.Sum(c => c.Sum()), totalDocs);
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        for (int i = 0; i < Utils.Size(invDocFreqs[iinfo]); i++)
                        {
                            if (invDocFreqs[iinfo][i] != 0)
                            {
                                invDocFreqs[iinfo][i] = Math.Log(totalDocs / invDocFreqs[iinfo][i]);
                            }
                        }
                    }

                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);

                        int ngramLength = _exes[iinfo].NgramLength;
                        for (int i = 0; i < ngramLength; i++)
                        {
                            _exes[iinfo].NonEmptyLevels[i] = counts[iinfo][i] > 0;
                        }
                    }

                    return(ngramMaps);
                }
        }
Example No. 23
        private MatrixFactorizationModelParameters TrainCore(IChannel ch, RoleMappedData data, RoleMappedData validData = null)
        {
            _host.AssertValue(ch);
            ch.AssertValue(data);
            ch.AssertValueOrNull(validData);

            ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Input data did not have a unique label");
            RecommenderUtils.CheckAndGetMatrixIndexColumns(data, out var matrixColumnIndexColInfo, out var matrixRowIndexColInfo, isDecode: false);
            var labelCol = data.Schema.Label.Value;

            if (labelCol.Type != NumberDataViewType.Single && labelCol.Type != NumberDataViewType.Double)
            {
                throw ch.Except("Column '{0}' for label should be floating point, but is instead {1}", labelCol.Name, labelCol.Type);
            }
            MatrixFactorizationModelParameters predictor;

            if (validData != null)
            {
                ch.CheckValue(validData, nameof(validData));
                ch.CheckParam(validData.Schema.Label.HasValue, nameof(validData), "Input validation data did not have a unique label");
                RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out var validMatrixColumnIndexColInfo, out var validMatrixRowIndexColInfo, isDecode: false);
                var validLabelCol = validData.Schema.Label.Value;
                if (validLabelCol.Type != NumberDataViewType.Single && validLabelCol.Type != NumberDataViewType.Double)
                {
                    throw ch.Except("Column '{0}' for validation label should be floating point, but is instead {1}", validLabelCol.Name, validLabelCol.Type);
                }

                if (!matrixColumnIndexColInfo.Type.Equals(validMatrixColumnIndexColInfo.Type))
                {
                    throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-column types differed, {0} vs. {1}",
                                         matrixColumnIndexColInfo.Type, validMatrixColumnIndexColInfo.Type);
                }
                if (!matrixRowIndexColInfo.Type.Equals(validMatrixRowIndexColInfo.Type))
                {
                    throw ch.ExceptParam(nameof(validData), "Train and validation sets' matrix-row types differed, {0} vs. {1}",
                                         matrixRowIndexColInfo.Type, validMatrixRowIndexColInfo.Type);
                }
            }

            int colCount = matrixColumnIndexColInfo.Type.GetKeyCountAsInt32(_host);
            int rowCount = matrixRowIndexColInfo.Type.GetKeyCountAsInt32(_host);

            ch.Assert(rowCount > 0);
            ch.Assert(colCount > 0);

            // Checks for equality on the validation set ensure it is correct here.
            using (var cursor = data.Data.GetRowCursor(matrixColumnIndexColInfo, matrixRowIndexColInfo, data.Schema.Label.Value))
            {
                // LibMF works only over single precision floats, but we want to be able to consume either.
                var labGetter = RowCursorUtils.GetGetterAs <float>(NumberDataViewType.Single, cursor, data.Schema.Label.Value.Index);
                var matrixColumnIndexGetter = RowCursorUtils.GetGetterAs <uint>(NumberDataViewType.UInt32, cursor, matrixColumnIndexColInfo.Index);
                var matrixRowIndexGetter    = RowCursorUtils.GetGetterAs <uint>(NumberDataViewType.UInt32, cursor, matrixRowIndexColInfo.Index);

                if (validData == null)
                {
                    // Have the trainer do its work.
                    using (var buffer = PrepareBuffer())
                    {
                        buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
                        predictor = new MatrixFactorizationModelParameters(_host, buffer, (KeyDataViewType)matrixColumnIndexColInfo.Type, (KeyDataViewType)matrixRowIndexColInfo.Type);
                    }
                }
                else
                {
                    RecommenderUtils.CheckAndGetMatrixIndexColumns(validData, out var validMatrixColumnIndexColInfo, out var validMatrixRowIndexColInfo, isDecode: false);
                    using (var validCursor = validData.Data.GetRowCursor(matrixColumnIndexColInfo, matrixRowIndexColInfo, data.Schema.Label.Value))
                    {
                        ValueGetter <float> validLabelGetter = RowCursorUtils.GetGetterAs <float>(NumberDataViewType.Single, validCursor, validData.Schema.Label.Value.Index);
                        var validMatrixColumnIndexGetter     = RowCursorUtils.GetGetterAs <uint>(NumberDataViewType.UInt32, validCursor, validMatrixColumnIndexColInfo.Index);
                        var validMatrixRowIndexGetter        = RowCursorUtils.GetGetterAs <uint>(NumberDataViewType.UInt32, validCursor, validMatrixRowIndexColInfo.Index);

                        // Have the trainer do its work.
                        using (var buffer = PrepareBuffer())
                        {
                            buffer.TrainWithValidation(ch, rowCount, colCount,
                                                       cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
                                                       validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter);
                            predictor = new MatrixFactorizationModelParameters(_host, buffer, (KeyDataViewType)matrixColumnIndexColInfo.Type, (KeyDataViewType)matrixRowIndexColInfo.Type);
                        }
                    }
                }
            }
            return(predictor);
        }
Example No. 24
        private void Train(IChannel ch, IDataView trainingData, LdaState[] states)
        {
            Host.AssertValue(ch);
            ch.AssertValue(trainingData);
            ch.AssertValue(states);
            ch.Assert(states.Length == Infos.Length);

            bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
            int[]  numVocabs     = new int[Infos.Length];

            for (int i = 0; i < Infos.Length; i++)
            {
                activeColumns[Infos[i].Source] = true;
                numVocabs[i] = 0;
            }

            // The current LDA implementation needs to allocate memory before data can be fed in, so it requires
            // two sweeps over the data: one to pre-calculate the memory needed, and one to actually feed the data.
            // An alternative would be to compute these two values externally and place them at the beginning of the input file.
            long[] corpusSize  = new long[Infos.Length];
            int[]  numDocArray = new int[Infos.Length];

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                var getters = new ValueGetter <VBuffer <Double> > [Utils.Size(Infos)];
                for (int i = 0; i < Infos.Length; i++)
                {
                    corpusSize[i]  = 0;
                    numDocArray[i] = 0;
                    getters[i]     = RowCursorUtils.GetVecGetterAs <Double>(NumberType.R8, cursor, Infos[i].Source);
                }
                VBuffer <Double> src      = default(VBuffer <Double>);
                long             rowCount = 0;

                while (cursor.MoveNext())
                {
                    ++rowCount;
                    for (int i = 0; i < Infos.Length; i++)
                    {
                        int docSize = 0;
                        getters[i](ref src);

                        // compute term, doc instance#.
                        for (int termID = 0; termID < src.Count; termID++)
                        {
                            int termFreq = GetFrequency(src.Values[termID]);
                            if (termFreq < 0)
                            {
                                // Ignore this row.
                                docSize = 0;
                                break;
                            }

                            if (docSize >= _exes[i].NumMaxDocToken - termFreq)
                            {
                                break; //control the document length
                            }
                            //if legal then add the term
                            docSize += termFreq;
                        }

                        // Ignore empty doc
                        if (docSize == 0)
                        {
                            continue;
                        }

                        numDocArray[i]++;
                        corpusSize[i] += docSize * 2 + 1;   // at the beginning of each doc, there is a cursor variable

                        // increase numVocab if needed.
                        if (numVocabs[i] < src.Length)
                        {
                            numVocabs[i] = src.Length;
                        }
                    }
                }

                for (int i = 0; i < Infos.Length; ++i)
                {
                    if (numDocArray[i] != rowCount)
                    {
                        ch.Assert(numDocArray[i] < rowCount);
                        ch.Warning($"Column '{Infos[i].Name}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows that were either empty or contained negative, non-finite, or fractional values.");
                    }
                }
            }

            // Initialize all LDA states
            for (int i = 0; i < Infos.Length; i++)
            {
                var state = new LdaState(Host, _exes[i], numVocabs[i]);
                if (numDocArray[i] == 0 || corpusSize[i] == 0)
                {
                    throw ch.Except("The specified documents are all empty in column '{0}'.", Infos[i].Name);
                }

                state.AllocateDataMemory(numDocArray[i], corpusSize[i]);
                states[i] = state;
            }

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                int[] docSizeCheck = new int[Infos.Length];
                // This could be optimized so that if multiple trainers consume the same column, it is
                // fed into the train method once.
                var getters = new ValueGetter <VBuffer <Double> > [Utils.Size(Infos)];
                for (int i = 0; i < Infos.Length; i++)
                {
                    docSizeCheck[i] = 0;
                    getters[i]      = RowCursorUtils.GetVecGetterAs <Double>(NumberType.R8, cursor, Infos[i].Source);
                }

                VBuffer <Double> src = default(VBuffer <Double>);

                while (cursor.MoveNext())
                {
                    for (int i = 0; i < Infos.Length; i++)
                    {
                        getters[i](ref src);
                        docSizeCheck[i] += states[i].FeedTrain(Host, ref src);
                    }
                }
                for (int i = 0; i < Infos.Length; i++)
                {
                    Host.Assert(corpusSize[i] == docSizeCheck[i]);
                    states[i].CompleteTrain();
                }
            }
        }