コード例 #1
0
        /// <summary>
        /// Creates the getter for output column <paramref name="iinfo"/>. The getter reads the
        /// source column of <paramref name="input"/> and is requested from
        /// <see cref="RowCursorUtils"/> as the destination type recorded in <c>_exes[iinfo].TypeDst</c>.
        /// </summary>
        /// <param name="ch">Channel; may be null.</param>
        /// <param name="input">Row that the source column is read from.</param>
        /// <param name="iinfo">Index of the column; must be in the range [0, Infos.Length).</param>
        /// <param name="disposer">Always set to null by this method.</param>
        /// <returns>
        /// The delegate returned by <c>GetGetterAs</c> for a non-vector destination type, or by
        /// <c>GetVecGetterAs</c> (with the destination item type) for a vector destination type.
        /// </returns>
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Host.AssertValueOrNull(ch);
            Host.AssertValue(input);
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            disposer = null;

            int srcCol = Infos[iinfo].Source;
            var typeDst = _exes[iinfo].TypeDst;

            if (typeDst.IsVector)
            {
                // Vector destination: the conversion is requested per item type.
                return RowCursorUtils.GetVecGetterAs(typeDst.AsVector.ItemType, input, srcCol);
            }

            return RowCursorUtils.GetGetterAs(typeDst, input, srcCol);
        }
コード例 #2
0
            /// <summary>
            /// Binds the label, score and (when a weight column is present) weight getters to
            /// <paramref name="row"/> before a pass over the data begins.
            /// </summary>
            /// <param name="row">Row the getters are created on.</param>
            /// <param name="schema">Role mapping supplying the label, score and optional weight columns.</param>
            public override void InitializeNextPass(IRow row, RoleMappedSchema schema)
            {
                Contracts.Assert(PassNum < 1);
                Contracts.AssertValue(schema.Label);

                var scoreColumn = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);

                // The label is requested as a vector of Float; the score is read with its own type.
                _labelGetter = RowCursorUtils.GetVecGetterAs<Float>(NumberType.Float, row, schema.Label.Index);
                _scoreGetter = row.GetGetter<VBuffer<Float>>(scoreColumn.Index);
                Contracts.AssertValue(_labelGetter);
                Contracts.AssertValue(_scoreGetter);

                // Without a weight column the weight getter is left untouched.
                if (schema.Weight == null)
                {
                    return;
                }

                _weightGetter = row.GetGetter<Float>(schema.Weight.Index);
            }
コード例 #3
0
        /// <summary>
        /// Returns a getter that exposes the slot values of <paramref name="cursor"/> as a vector
        /// of <see cref="Single"/> labels.
        /// </summary>
        /// <remarks>
        /// R4 slots are returned through the cursor's own getter; R8 and boolean slots go through
        /// <c>GetVecGetterAs</c>. Key-typed slots are read as <c>ulong</c> and mapped to
        /// <c>key - 1</c>; a key of zero or one above the key count becomes <see cref="Single.NaN"/>.
        /// Any other item type fails the check with an exception.
        /// </remarks>
        /// <param name="cursor">Slot cursor whose item type determines the conversion.</param>
        public static ValueGetter<VBuffer<Single>> GetLabelGetter(ISlotCursor cursor)
        {
            var itemType = cursor.GetSlotType().ItemType;

            if (itemType == NumberType.R4)
            {
                return cursor.GetGetter<Single>();
            }

            bool convertible = itemType == NumberType.R8 || itemType.IsBool;
            if (convertible)
            {
                return GetVecGetterAs<Single>(NumberType.R4, cursor);
            }

            Contracts.Check(itemType.IsKey, "Only floating point number, boolean, and key type values can be used as label.");
            Contracts.Assert(TestGetLabelGetter(itemType) == null);

            // A key count of zero is treated as "no upper bound" on the key value.
            ulong keyLimit = (ulong)itemType.KeyCount;
            if (keyLimit == 0)
            {
                keyLimit = ulong.MaxValue;
            }

            var getKeys = RowCursorUtils.GetVecGetterAs<ulong>(NumberType.U8, cursor);
            VBuffer<ulong> keys = default(VBuffer<ulong>);

            return (ref VBuffer<Single> dst) =>
            {
                getKeys(ref keys);

                // Unfortunately defaults in one do not translate to defaults of the other,
                // so this will not be sparsity preserving. Assume a dense output.
                Single[] labels = dst.Values;
                Utils.EnsureSize(ref labels, keys.Length);

                foreach (var item in keys.Items(all: true))
                {
                    ulong key = item.Value;
                    bool inRange = 0 < key && key <= keyLimit;
                    labels[item.Key] = inRange ? (Single)(key - 1) : Single.NaN;
                }

                dst = new VBuffer<Single>(keys.Length, labels, dst.Indices);
            };
        }
コード例 #4
0
        /// <summary>
        /// Bagging case: the input is a vector of keys and every in-range key contributes 1 to
        /// the output slot with the same (zero-based) key value, so the outputs are added together.
        /// </summary>
        /// <param name="input">Row that the source key vector is read from.</param>
        /// <param name="iinfo">Index of the column in <c>Infos</c>.</param>
        private ValueGetter<VBuffer<Float>> MakeGetterBag(IRow input, int iinfo)
        {
            Host.AssertValue(input);
            Host.Assert(Infos[iinfo].TypeSrc.IsVector);
            Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);
            Host.Assert(_bag[iinfo]);
            Host.Assert(Infos[iinfo].TypeSrc.ItemType.KeyCount == _types[iinfo].VectorSize);

            var info = Infos[iinfo];

            int size = info.TypeSrc.ItemType.KeyCount;
            Host.Assert(size > 0);

            int cv = info.TypeSrc.VectorSize;
            Host.Assert(cv >= 0);

            var getKeys = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, info.Source);
            var keys = default(VBuffer<uint>);
            var builder = BufferBuilder<float>.CreateDefault();

            return (ref VBuffer<Float> dst) =>
            {
                builder.Reset(size, false);

                getKeys(ref keys);
                Host.Check(cv == 0 || keys.Length == cv);

                // The indices are irrelevant in the bagging case: only the stored values are visited.
                var stored = keys.Values;
                for (int i = 0, n = keys.Count; i < n; i++)
                {
                    // A key of 0 wraps around to uint.MaxValue here and is rejected by the range test.
                    uint zeroBased = stored[i] - 1;
                    if (zeroBased >= size)
                    {
                        continue;
                    }

                    builder.AddFeature((int)zeroBased, 1);
                }

                builder.GetResult(ref dst);
            };
        }
コード例 #5
0
        /// <summary>
        /// Builds a <see cref="Single"/> label getter for a label column that is neither R4 nor R8.
        /// </summary>
        /// <remarks>
        /// Boolean labels map true to 1 and false to 0. Key labels are read as <c>ulong</c> and
        /// mapped to <c>key - 1</c>; a key of zero or one above the key count becomes
        /// <see cref="Single.NaN"/>. Any other type fails the check with an exception.
        /// </remarks>
        /// <param name="cursor">Row the label is read from.</param>
        /// <param name="labelIndex">Index of the label column in the row's schema.</param>
        private static ValueGetter<Single> GetLabelGetterNotFloat(IRow cursor, int labelIndex)
        {
            var labelType = cursor.Schema.GetColumnType(labelIndex);
            Contracts.Assert(labelType != NumberType.R4 && labelType != NumberType.R8);

            if (labelType.IsBool)
            {
                // boolean type label mapping: True -> 1, False -> 0.
                var readFlag = cursor.GetGetter<bool>(labelIndex);
                return (ref Single dst) =>
                {
                    bool flag = default;
                    readFlag(ref flag);
                    dst = flag ? 1f : 0f;
                };
            }

            Contracts.Check(labelType.IsKey, "Only floating point number, boolean, and key type values can be used as label.");
            Contracts.Assert(TestGetLabelGetter(labelType) == null);

            // A key count of zero is treated as "no upper bound" on the key value.
            ulong keyLimit = (ulong)labelType.KeyCount;
            if (keyLimit == 0)
            {
                keyLimit = ulong.MaxValue;
            }

            var readKey = RowCursorUtils.GetGetterAs<ulong>(NumberType.U8, cursor, labelIndex);

            return (ref Single dst) =>
            {
                ulong key = 0;
                readKey(ref key);

                bool inRange = 0 < key && key <= keyLimit;
                dst = inRange ? (Single)(key - 1) : Single.NaN;
            };
        }
コード例 #6
0
            /// <summary>
            /// Prepares the getters and counters for the upcoming pass. The score getter (and the
            /// feature getter when <c>_calculateDbi</c> is set) is rebound on every pass; label and
            /// weight getters are bound on pass 0, and on the following pass the counters are
            /// switched to their second-pass state using the cluster centroids.
            /// </summary>
            /// <param name="row">Row the getters are created on.</param>
            /// <param name="schema">Role mapping supplying feature, score, label and weight columns.</param>
            public override void InitializeNextPass(IRow row, RoleMappedSchema schema)
            {
                AssertValid(assertGetters: false);

                Host.AssertValue(row);
                Host.AssertValue(schema);

                if (_calculateDbi)
                {
                    Host.AssertValue(schema.Feature);
                    _featGetter = row.GetGetter<VBuffer<Single>>(schema.Feature.Index);
                }

                var scoreColumn = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
                Host.Assert(scoreColumn.Type.VectorSize == _scoresArr.Length);
                _scoreGetter = row.GetGetter<VBuffer<Single>>(scoreColumn.Index);

                if (PassNum != 0)
                {
                    // Only a second pass is expected, and only when the DBI is being calculated.
                    Host.Assert(PassNum == 1 && _calculateDbi);
                    UnweightedCounters.InitializeSecondPass(_clusterCentroids);
                    WeightedCounters?.InitializeSecondPass(_clusterCentroids);
                }
                else
                {
                    if (schema.Label == null)
                    {
                        // No label column: every row reports NaN as its label.
                        _labelGetter = (ref Single value) => value = Single.NaN;
                    }
                    else
                    {
                        _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
                    }

                    if (schema.Weight != null)
                    {
                        _weightGetter = row.GetGetter<Single>(schema.Weight.Index);
                    }
                }

                AssertValid(assertGetters: true);
            }
コード例 #7
0
        /// <summary>
        /// Creates a copy of <paramref name="column"/>, which must be an <c>IColumn&lt;T&gt;</c>.
        /// Metadata, when present, is cloned through <c>RowCursorUtils.CloneRow</c>. An inactive
        /// column yields an inactive copy; an active column has its current value read once and
        /// handed to <c>GetColumn</c>.
        /// </summary>
        /// <typeparam name="T">Value type of the column.</typeparam>
        /// <param name="column">Column to copy.</param>
        private static IColumn CloneColumnCore<T>(IColumn column)
        {
            Contracts.Assert(column is IColumn<T>);

            IRow metadata = column.Metadata;
            if (metadata != null)
            {
                metadata = RowCursorUtils.CloneRow(metadata);
            }

            var typed = (IColumn<T>)column;
            if (typed.IsActive)
            {
                // Read the current value out of the column's getter.
                T value = default(T);
                var getter = typed.GetGetter();
                getter(ref value);
                return GetColumn(typed.Name, typed.Type, ref value, metadata);
            }

            return new InactiveImpl<T>(typed.Name, metadata, typed.Type);
        }
コード例 #8
0
        /// <summary>
        /// Singleton case: a single key value becomes an indicator vector. This should be
        /// equivalent to both Bag and Ord over a vector of size one.
        /// </summary>
        /// <remarks>
        /// A key of zero or a key above the key count produces a vector with no explicit entries;
        /// otherwise the output holds exactly one entry of value 1 at index <c>key - 1</c>.
        /// The buffers of the destination are reused when they are large enough.
        /// </remarks>
        /// <param name="input">Row that the source key is read from.</param>
        /// <param name="iinfo">Index of the column in <c>Infos</c>.</param>
        private ValueGetter<VBuffer<Float>> MakeGetterOne(IRow input, int iinfo)
        {
            Host.AssertValue(input);
            Host.Assert(Infos[iinfo].TypeSrc.IsKey);
            Host.Assert(Infos[iinfo].TypeSrc.KeyCount == _types[iinfo].VectorSize);

            int size = Infos[iinfo].TypeSrc.KeyCount;
            Host.Assert(size > 0);

            var readKey = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, input, Infos[iinfo].Source);
            var key = default(uint);

            return (ref VBuffer<Float> dst) =>
            {
                readKey(ref key);

                bool inRange = key != 0 && key <= size;
                if (inRange)
                {
                    var values = dst.Values;
                    if (Utils.Size(values) < 1)
                    {
                        values = new Float[1];
                    }

                    var indices = dst.Indices;
                    if (Utils.Size(indices) < 1)
                    {
                        indices = new int[1];
                    }

                    values[0] = 1;
                    indices[0] = (int)key - 1;

                    dst = new VBuffer<Float>(size, 1, values, indices);
                }
                else
                {
                    dst = new VBuffer<Float>(size, 0, dst.Values, dst.Indices);
                }
            };
        }
コード例 #9
0
        /// <summary>
        /// Builds the getter for output column <paramref name="iinfo"/>: on every call each source
        /// key vector is read in turn and its n-grams are added to a shared
        /// <c>NgramBufferBuilder</c>, whose result is written to the destination.
        /// </summary>
        /// <param name="ch">Channel used for the precondition assert.</param>
        /// <param name="input">Row that the source columns are read from.</param>
        /// <param name="iinfo">Index of the column in the bindings.</param>
        /// <param name="decorator">Optional wrapper applied to the n-gram id finder; may be null.</param>
        /// <returns>A <c>ValueGetter&lt;VBuffer&lt;Float&gt;&gt;</c> as a <see cref="Delegate"/>.</returns>
        private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator decorator = null)
        {
            ch.Assert(_bindings.Infos[iinfo].SrcTypes.All(t => t.IsVector && t.ItemType.IsKey));

            var info = _bindings.Infos[iinfo];
            int srcCount = info.SrcIndices.Length;

            // One U4 vector getter per source column, in source order.
            ValueGetter<VBuffer<uint>>[] srcGetters = info.SrcIndices
                .Select(col => RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, col))
                .ToArray();
            var keys = default(VBuffer<uint>);

            var finder = GetNgramIdFinder(iinfo);
            if (decorator != null)
            {
                finder = decorator(iinfo, finder);
            }

            var exInfo = _exes[iinfo];
            var builder = new NgramBufferBuilder(exInfo.NgramLength, exInfo.SkipLength,
                                                 _bindings.Types[iinfo].ValueCount, finder);

            // A non-positive key count is replaced by uint.MaxValue.
            var keyCounts = info.SrcTypes
                .Select(t => t.ItemType.KeyCount > 0 ? (uint)t.ItemType.KeyCount : uint.MaxValue)
                .ToArray();

            // REVIEW: Special casing the srcCount==1 case could potentially improve perf.
            ValueGetter<VBuffer<Float>> getter = (ref VBuffer<Float> dst) =>
            {
                builder.Reset();
                for (int isrc = 0; isrc < srcCount; isrc++)
                {
                    srcGetters[isrc](ref keys);
                    builder.AddNgrams(ref keys, isrc, keyCounts[isrc]);
                }
                builder.GetResult(ref dst);
            };

            return getter;
        }
コード例 #10
0
        /// <summary>
        /// Creates the getters for the five output columns (label, score, L1, L2 and distance).
        /// Label and score vectors are read at most once per row position and cached; each active
        /// output getter refreshes the cache before producing its value.
        /// </summary>
        /// <param name="input">Row that the label and score columns are read from.</param>
        /// <param name="activeCols">Predicate telling which output columns are active.</param>
        /// <param name="disposer">Always set to null by this method.</param>
        /// <returns>An array of five delegates; entries of inactive columns are left null.</returns>
        public override Delegate[] CreateGetters(IRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);

            disposer = null;

            // Row-level cache shared by all of the getters created below.
            long cachedPosition = -1;
            var label = default(VBuffer<Float>);
            var score = default(VBuffer<Float>);

            ValueGetter<VBuffer<Float>> emptyGetter = (ref VBuffer<Float> vec) => vec = default(VBuffer<Float>);

            // The source getters are only bound when some active output needs them.
            ValueGetter<VBuffer<Float>> labelGetter;
            if (activeCols(LabelOutput) || activeCols(L1Output) || activeCols(L2Output) || activeCols(DistCol))
            {
                labelGetter = RowCursorUtils.GetVecGetterAs<Float>(NumberType.Float, input, LabelIndex);
            }
            else
            {
                labelGetter = emptyGetter;
            }

            ValueGetter<VBuffer<Float>> scoreGetter;
            if (activeCols(ScoreOutput) || activeCols(L1Output) || activeCols(L2Output) || activeCols(DistCol))
            {
                scoreGetter = input.GetGetter<VBuffer<Float>>(ScoreIndex);
            }
            else
            {
                scoreGetter = emptyGetter;
            }

            Action refreshCache = () =>
            {
                if (cachedPosition == input.Position)
                {
                    return;
                }

                labelGetter(ref label);
                scoreGetter(ref score);
                cachedPosition = input.Position;
            };

            var result = new Delegate[5];

            if (activeCols(LabelOutput))
            {
                result[LabelOutput] = (ValueGetter<VBuffer<Float>>)((ref VBuffer<Float> dst) =>
                {
                    refreshCache();
                    label.CopyTo(ref dst);
                });
            }

            if (activeCols(ScoreOutput))
            {
                result[ScoreOutput] = (ValueGetter<VBuffer<Float>>)((ref VBuffer<Float> dst) =>
                {
                    refreshCache();
                    score.CopyTo(ref dst);
                });
            }

            if (activeCols(L1Output))
            {
                result[L1Output] = (ValueGetter<double>)((ref double dst) =>
                {
                    refreshCache();
                    dst = VectorUtils.L1Distance(ref label, ref score);
                });
            }

            if (activeCols(L2Output))
            {
                result[L2Output] = (ValueGetter<double>)((ref double dst) =>
                {
                    refreshCache();
                    dst = VectorUtils.L2DistSquared(ref label, ref score);
                });
            }

            if (activeCols(DistCol))
            {
                // Distance is the square root of the squared L2 distance.
                result[DistCol] = (ValueGetter<double>)((ref double dst) =>
                {
                    refreshCache();
                    dst = MathUtils.Sqrt(VectorUtils.L2DistSquared(ref label, ref score));
                });
            }

            return result;
        }
コード例 #11
0
        /// <summary>
        /// Indicator (non-bagging) case: the input is a vector of keys and the per-slot indicator
        /// vectors are concatenated, so source slot <c>s</c> with zero-based key <c>k</c> sets
        /// output index <c>s * size + k</c> to 1.
        /// </summary>
        /// <param name="input">Row that the source key vector is read from.</param>
        /// <param name="iinfo">Index of the column in <c>Infos</c>.</param>
        private ValueGetter<VBuffer<Float>> MakeGetterInd(IRow input, int iinfo)
        {
            Host.AssertValue(input);
            Host.Assert(Infos[iinfo].TypeSrc.IsVector);
            Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);
            Host.Assert(!_bag[iinfo]);

            var info = Infos[iinfo];

            int size = info.TypeSrc.ItemType.KeyCount;
            Host.Assert(size > 0);

            int cv = info.TypeSrc.VectorSize;
            Host.Assert(cv >= 0);
            Host.Assert(_types[iinfo].VectorSize == size * cv);

            var getKeys = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, info.Source);
            var keys = default(VBuffer<uint>);

            return (ref VBuffer<Float> dst) =>
            {
                getKeys(ref keys);
                int lenSrc = keys.Length;
                Host.Check(lenSrc == cv || cv == 0);

                // Since we generate values in order, no need for a builder.
                var dstValues = dst.Values;
                var dstIndices = dst.Indices;

                int lenDst = checked(size * lenSrc);
                int cntSrc = keys.Count;

                // At most one output entry is produced per stored source value.
                if (Utils.Size(dstValues) < cntSrc)
                {
                    dstValues = new Float[cntSrc];
                }
                if (Utils.Size(dstIndices) < cntSrc)
                {
                    dstIndices = new int[cntSrc];
                }

                var srcValues = keys.Values;
                bool isDense = keys.IsDense;
                var srcIndices = isDense ? null : keys.Indices;
                Host.Assert(!isDense || lenSrc == cntSrc);

                int written = 0;
                for (int i = 0; i < cntSrc; i++)
                {
                    Host.Assert(written < cntSrc);

                    // A key of 0 wraps around to uint.MaxValue here and is rejected by the range test.
                    uint zeroBased = srcValues[i] - 1;
                    if (zeroBased < (uint)size)
                    {
                        // Dense input: the position is the slot. Sparse input: look the slot up.
                        int slot = isDense ? i : srcIndices[i];
                        dstValues[written] = 1;
                        dstIndices[written] = slot * size + (int)zeroBased;
                        written++;
                    }
                }

                dst = new VBuffer<Float>(lenDst, written, dstValues, dstIndices);
            };
        }
コード例 #12
0
        /// <summary>
        /// Creates the getters for the per-slot L1 and L2 output columns. For each row the
        /// absolute difference between the scalar label and every score slot is cached; the L1
        /// getter copies that cache and the L2 getter writes its element-wise square.
        /// </summary>
        /// <param name="input">Row that the label and score columns are read from.</param>
        /// <param name="activeCols">Predicate telling which output columns are active.</param>
        /// <param name="disposer">Always set to null by this method.</param>
        /// <returns>An array of two delegates; entries of inactive columns are left null.</returns>
        public override Delegate[] CreateGetters(IRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);

            disposer = null;

            // Row-level cache shared by both getters created below.
            long cachedPosition = -1;
            Float label = 0;
            var score = default(VBuffer<Float>);
            var absDiff = VBufferUtils.CreateDense<Double>(_scoreSize);

            ValueGetter<Float> nanGetter = (ref Float value) => value = Single.NaN;
            ValueGetter<Float> labelGetter;
            if (activeCols(L1Col) || activeCols(L2Col))
            {
                labelGetter = RowCursorUtils.GetLabelGetter(input, LabelIndex);
            }
            else
            {
                labelGetter = nanGetter;
            }

            ValueGetter<VBuffer<Float>> emptyScoreGetter = (ref VBuffer<Float> dst) => dst = default(VBuffer<Float>);
            ValueGetter<VBuffer<Float>> scoreGetter =
                activeCols(L1Col) || activeCols(L2Col)
                    ? input.GetGetter<VBuffer<Float>>(ScoreIndex)
                    : emptyScoreGetter;

            Action refreshCache = () =>
            {
                if (cachedPosition == input.Position)
                {
                    return;
                }

                labelGetter(ref label);
                scoreGetter(ref score);

                Double labelValue = label;
                foreach (var slot in score.Items(all: true))
                {
                    absDiff.Values[slot.Key] = Math.Abs(labelValue - slot.Value);
                }

                cachedPosition = input.Position;
            };

            var result = new Delegate[2];

            if (activeCols(L1Col))
            {
                result[L1Col] = (ValueGetter<VBuffer<Double>>)((ref VBuffer<Double> dst) =>
                {
                    refreshCache();
                    absDiff.CopyTo(ref dst);
                });
            }

            if (activeCols(L2Col))
            {
                VBufferUtils.PairManipulator<Double, Double> square =
                    (int slot, Double x, ref Double y) => y = x * x;

                result[L2Col] = (ValueGetter<VBuffer<Double>>)((ref VBuffer<Double> dst) =>
                {
                    refreshCache();

                    // Start from a vector with no explicit entries, then fill it from the cache.
                    dst = new VBuffer<Double>(_scoreSize, 0, dst.Values, dst.Indices);
                    VBufferUtils.ApplyWith(ref absDiff, ref dst, square);
                });
            }

            return result;
        }
コード例 #13
0
        /// <summary>
        /// Creates the getter for output column <paramref name="iinfo"/>. Each call reads the
        /// source key vector and, unless the n-gram builder is empty, produces the n-gram counts
        /// and applies the column's weighting: raw counts for Tf, counts multiplied by the inverse
        /// document frequency for TfIdf, and the inverse document frequency of slots with a count
        /// of at least 1 for Idf. When the builder is empty the output is a zero-length vector.
        /// </summary>
        /// <param name="ch">Channel; may be null.</param>
        /// <param name="input">Row that the source column is read from.</param>
        /// <param name="iinfo">Index of the column; must be in the range [0, Infos.Length).</param>
        /// <param name="disposer">Always set to null by this method.</param>
        /// <returns>A <c>ValueGetter&lt;VBuffer&lt;Float&gt;&gt;</c> as a <see cref="Delegate"/>.</returns>
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Host.AssertValueOrNull(ch);
            Host.AssertValue(input);
            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
            Host.Assert(Infos[iinfo].TypeSrc.IsVector);
            Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);

            disposer = null;

            var readKeys = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, Infos[iinfo].Source);
            var keys = default(VBuffer<uint>);
            var builder = new NgramBufferBuilder(_exes[iinfo].NgramLength, _exes[iinfo].SkipLength,
                                                 _ngramMaps[iinfo].Count, GetNgramIdFinder(iinfo));

            // A key count of zero is replaced by uint.MaxValue.
            uint keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
            if (keyCount == 0)
            {
                keyCount = uint.MaxValue;
            }

            ValueGetter<VBuffer<Float>> getter;

            switch (_exes[iinfo].Weighting)
            {
            case WeightingCriteria.TfIdf:
                Host.AssertValue(_invDocFreqs[iinfo]);
                getter = (ref VBuffer<Float> dst) =>
                {
                    readKeys(ref keys);
                    if (builder.IsEmpty)
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                        return;
                    }

                    builder.Reset();
                    builder.AddNgrams(ref keys, 0, keyCount);
                    builder.GetResult(ref dst);
                    VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = (Float)(v * _invDocFreqs[iinfo][i]));
                };
                break;

            case WeightingCriteria.Idf:
                Host.AssertValue(_invDocFreqs[iinfo]);
                getter = (ref VBuffer<Float> dst) =>
                {
                    readKeys(ref keys);
                    if (builder.IsEmpty)
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                        return;
                    }

                    builder.Reset();
                    builder.AddNgrams(ref keys, 0, keyCount);
                    builder.GetResult(ref dst);
                    VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = v >= 1 ? (Float)_invDocFreqs[iinfo][i] : 0);
                };
                break;

            case WeightingCriteria.Tf:
                getter = (ref VBuffer<Float> dst) =>
                {
                    readKeys(ref keys);
                    if (builder.IsEmpty)
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                        return;
                    }

                    builder.Reset();
                    builder.AddNgrams(ref keys, 0, keyCount);
                    builder.GetResult(ref dst);
                };
                break;

            default:
                throw Host.Except("Unsupported weighting criteria");
            }

            return getter;
        }
コード例 #14
0
        /// <summary>
        /// Builds the n-gram dictionary of every column by streaming over
        /// <paramref name="trainingData"/>, and computes inverse document frequencies for the
        /// columns whose weighting requires them.
        /// </summary>
        /// <param name="args">Arguments supplying the per-column and global <c>AllLengths</c> and <c>MaxNumTerms</c> settings.</param>
        /// <param name="trainingData">Data whose source columns (vectors of keys) are scanned.</param>
        /// <param name="invDocFreqs">
        /// On return, one array per column. For columns where <c>RequireIdf()</c> is true, each
        /// non-zero entry is <c>log(totalDocs / documentCount)</c> for that n-gram slot; for other
        /// columns the entry stays null.
        /// </param>
        /// <returns>The n-gram pool built for each column, indexed like <c>Infos</c>.</returns>
        private SequencePool[] Train(Arguments args, IDataView trainingData, out double[][] invDocFreqs)
        {
            // Contains the maximum number of grams to store in the dictionary, for each level of ngrams,
            // from 1 (in position 0) up to ngramLength (in position ngramLength-1)
            var lims = new int[Infos.Length][];

            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                // Per-column settings win over the global ones when they are specified.
                var all         = args.Column[iinfo].AllLengths ?? args.AllLengths;
                var ngramLength = _exes[iinfo].NgramLength;
                var maxNumTerms = Utils.Size(args.Column[iinfo].MaxNumTerms) > 0 ? args.Column[iinfo].MaxNumTerms : args.MaxNumTerms;
                if (!all)
                {
                    // Only the longest n-gram level gets a limit; the shorter levels keep a limit of 0.
                    Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
                                      Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(args.MaxNumTerms));
                    lims[iinfo] = new int[ngramLength];
                    lims[iinfo][ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[0];
                }
                else
                {
                    // Every level gets a limit; levels beyond the supplied values reuse the last
                    // supplied value (or the default when none was supplied).
                    Host.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(args.MaxNumTerms));
                    Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 || maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(args.MaxNumTerms));
                    var extend = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[maxNumTerms.Length - 1];
                    lims[iinfo] = Utils.BuildArray(ngramLength,
                                                   i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
                }
            }

            // Per-column working state: n-gram builder, source getter and source buffer.
            var helpers = new NgramBufferBuilder[Infos.Length];
            var getters = new ValueGetter <VBuffer <uint> > [Infos.Length];
            var src     = new VBuffer <uint> [Infos.Length];

            // Keep track of how many grams are in the pool for each value of n. Position
            // i in counts[iinfo] counts how many (i+1)-grams are in the pool for column iinfo.
            var counts    = new int[Infos.Length][];
            var ngramMaps = new SequencePool[Infos.Length];

            // Only the source columns of this transform are activated on the cursor.
            bool[] activeInput = new bool[trainingData.Schema.ColumnCount];
            foreach (var info in Infos)
            {
                activeInput[info.Source] = true;
            }
            using (var cursor = trainingData.GetRowCursor(col => activeInput[col]))
                using (var pch = Host.StartProgressChannel("Building n-gram dictionary"))
                {
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        Host.Assert(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsKey);
                        var ngramLength = _exes[iinfo].NgramLength;
                        var skipLength  = _exes[iinfo].SkipLength;

                        getters[iinfo]   = RowCursorUtils.GetVecGetterAs <uint>(NumberType.U4, cursor, Infos[iinfo].Source);
                        src[iinfo]       = default(VBuffer <uint>);
                        counts[iinfo]    = new int[ngramLength];
                        ngramMaps[iinfo] = new SequencePool();

                        // Note: GetNgramIdFinderAdd will control how many ngrams of a specific length will
                        // be added (using lims[iinfo]), therefore we set slotLim to the maximum
                        helpers[iinfo] = new NgramBufferBuilder(ngramLength, skipLength, Utils.ArrayMaxSize,
                                                                GetNgramIdFinderAdd(counts[iinfo], lims[iinfo], ngramMaps[iinfo], _exes[iinfo].RequireIdf(), Host));
                    }

                    // NOTE(review): neither cInfoFull nor infoFull is modified anywhere in this
                    // method, so as written the scan below runs until the cursor is exhausted and
                    // every column is always processed. Confirm whether early termination was intended.
                    int    cInfoFull = 0;
                    bool[] infoFull  = new bool[Infos.Length];

                    invDocFreqs = new double[Infos.Length][];

                    long   totalDocs = 0;
                    // The row count is only passed to the progress reporter; NaN when it is not available.
                    Double rowCount  = trainingData.GetRowCount(true) ?? Double.NaN;
                    var    buffers   = new VBuffer <float> [Infos.Length];
                    pch.SetHeader(new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }),
                                  e => e.SetProgress(0, totalDocs, rowCount));
                    while (cInfoFull < Infos.Length && cursor.MoveNext())
                    {
                        totalDocs++;
                        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                        {
                            getters[iinfo](ref src[iinfo]);
                            // A key count of zero is replaced by uint.MaxValue.
                            var keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
                            if (keyCount == 0)
                            {
                                keyCount = uint.MaxValue;
                            }
                            if (!infoFull[iinfo])
                            {
                                // For IDF the builder is reset per row so that its result below
                                // describes the current row only.
                                if (_exes[iinfo].RequireIdf())
                                {
                                    helpers[iinfo].Reset();
                                }

                                helpers[iinfo].AddNgrams(ref src[iinfo], 0, keyCount);
                                if (_exes[iinfo].RequireIdf())
                                {
                                    // Grow the document-frequency array to the number of n-grams
                                    // seen so far, then add one for every slot present in this row.
                                    int totalNgrams = counts[iinfo].Sum();
                                    Utils.EnsureSize(ref invDocFreqs[iinfo], totalNgrams);
                                    helpers[iinfo].GetResult(ref buffers[iinfo]);
                                    foreach (var pair in buffers[iinfo].Items())
                                    {
                                        if (pair.Value >= 1)
                                        {
                                            invDocFreqs[iinfo][pair.Key] += 1;
                                        }
                                    }
                                }
                            }
                            AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);
                        }
                    }

                    pch.Checkpoint(counts.Sum(c => c.Sum()), totalDocs);
                    // Turn the document counts into inverse document frequencies. Slots with a
                    // count of zero are left at zero, which also avoids a division by zero.
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        for (int i = 0; i < Utils.Size(invDocFreqs[iinfo]); i++)
                        {
                            if (invDocFreqs[iinfo][i] != 0)
                            {
                                invDocFreqs[iinfo][i] = Math.Log(totalDocs / invDocFreqs[iinfo][i]);
                            }
                        }
                    }

                    // Record, per column, which n-gram lengths ended up with at least one entry.
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);

                        int ngramLength = _exes[iinfo].NgramLength;
                        for (int i = 0; i < ngramLength; i++)
                        {
                            _exes[iinfo].NonEmptyLevels[i] = counts[iinfo][i] > 0;
                        }
                    }

                    return(ngramMaps);
                }
        }
コード例 #15
0
        /// <summary>
        /// Prints the per-label metrics of one fold to <paramref name="ch"/>.
        /// The multi-output regression evaluator prints only the per-label metrics for each fold.
        /// </summary>
        /// <param name="ch">Channel used for checks, exceptions and for emitting the formatted table.</param>
        /// <param name="metrics">Metric data views keyed by metric kind; the
        /// <c>MetricKinds.OverallMetrics</c> entry is the one that is read.</param>
        /// <remarks>
        /// Only visible columns whose type is a known-size vector of R8 are printed; the is-weighted
        /// and stratification columns are excluded. Rows whose stratification value is greater than
        /// zero are ignored entirely. An exception is thrown when the overall metrics view is missing
        /// or when more than one weighted (or unweighted) non-stratified row is found.
        /// </remarks>
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary <string, IDataView> metrics)
        {
            IDataView fold;

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold))
            {
                throw ch.Except("No overall metrics found");
            }

            int  isWeightedCol;
            bool needWeighted = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol);

            int  stratCol;
            bool hasStrats = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol);
            int  stratVal;
            bool hasStratVals = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);

            // The two stratification columns are expected to be present (or absent) together.
            ch.Assert(hasStrats == hasStratVals);

            var colCount       = fold.Schema.ColumnCount;
            // Sparse by column index: an entry stays null for every column that is not printed.
            var vBufferGetters = new ValueGetter <VBuffer <double> > [colCount];

            using (var cursor = fold.GetRowCursor(col => true))
            {
                DvBool isWeighted = DvBool.False;
                ValueGetter <DvBool> isWeightedGetter;
                if (needWeighted)
                {
                    isWeightedGetter = cursor.GetGetter <DvBool>(isWeightedCol);
                }
                else
                {
                    // No is-weighted column: every row is reported as unweighted.
                    isWeightedGetter = (ref DvBool dst) => dst = DvBool.False;
                }

                ValueGetter <uint> stratGetter;
                if (hasStrats)
                {
                    var type = cursor.Schema.GetColumnType(stratCol);
                    stratGetter = RowCursorUtils.GetGetterAs <uint>(type, cursor, stratCol);
                }
                else
                {
                    // No stratification column: every row reports stratification value 0.
                    stratGetter = (ref uint dst) => dst = 0;
                }

                // Collect a getter for each vector-valued metric column and verify that all of
                // them have the same number of slots (one slot per label).
                int labelCount = 0;
                for (int i = 0; i < colCount; i++)
                {
                    if (fold.Schema.IsHidden(i) || (needWeighted && i == isWeightedCol) ||
                        (hasStrats && (i == stratCol || i == stratVal)))
                    {
                        continue;
                    }

                    var type = fold.Schema.GetColumnType(i);
                    if (type.IsKnownSizeVector && type.ItemType == NumberType.R8)
                    {
                        vBufferGetters[i] = cursor.GetGetter <VBuffer <double> >(i);
                        if (labelCount == 0)
                        {
                            labelCount = type.VectorSize;
                        }
                        else
                        {
                            ch.Check(labelCount == type.VectorSize, "All vector metrics should contain the same number of slots");
                        }
                    }
                }
                var labelNames = new DvText[labelCount];
                for (int j = 0; j < labelCount; j++)
                {
                    labelNames[j] = new DvText(string.Format("Label_{0}", j));
                }

                // Header row: a blank cell for the metric-name column followed by the label names.
                var sb = new StringBuilder();
                sb.AppendLine("Per-label metrics:");
                sb.AppendFormat("{0,12} ", " ");
                for (int i = 0; i < labelCount; i++)
                {
                    sb.AppendFormat(" {0,20}", labelNames[i]);
                }
                sb.AppendLine();

                VBuffer <Double> metricVals      = default(VBuffer <Double>);
                // When there is no is-weighted column the weighted row counts as already found.
                bool             foundWeighted   = !needWeighted;
                bool             foundUnweighted = false;
                uint             strat           = 0;
                while (cursor.MoveNext())
                {
                    // Filter out stratified rows before any bookkeeping. Otherwise such a row would
                    // be recorded as the weighted/unweighted row, which could raise a spurious
                    // "Multiple ... rows" exception or stop the scan before the real row is printed.
                    stratGetter(ref strat);
                    if (strat > 0)
                    {
                        continue;
                    }

                    isWeightedGetter(ref isWeighted);
                    if (foundWeighted && isWeighted.IsTrue || foundUnweighted && isWeighted.IsFalse)
                    {
                        throw ch.Except("Multiple {0} rows found in overall metrics data view",
                                        isWeighted.IsTrue ? "weighted" : "unweighted");
                    }
                    if (isWeighted.IsTrue)
                    {
                        foundWeighted = true;
                    }
                    else
                    {
                        foundUnweighted = true;
                    }

                    // One output line per metric column: the column name followed by one value per label.
                    for (int i = 0; i < colCount; i++)
                    {
                        if (vBufferGetters[i] != null)
                        {
                            vBufferGetters[i](ref metricVals);
                            ch.Assert(metricVals.Length == labelCount);

                            sb.AppendFormat("{0}{1,12}:", isWeighted.IsTrue ? "Weighted " : "", fold.Schema.GetColumnName(i));
                            foreach (var metric in metricVals.Items(all: true))
                            {
                                sb.AppendFormat(" {0,20:G20}", metric.Value);
                            }
                            sb.AppendLine();
                        }
                    }
                    // Both rows have been printed; nothing further in the view is needed.
                    if (foundUnweighted && foundWeighted)
                    {
                        break;
                    }
                }
                ch.Assert(foundUnweighted && foundWeighted);
                ch.Info(sb.ToString());
            }
        }
コード例 #16
0
        // Prints the per-fold results: first a table built from the top-k results data view
        // (one line per row with instance, anomaly score and label), then reads the number of
        // anomalies from the overall metrics data view.
        // Throws (via Host.Except) when a required data view or column is missing, or when the
        // overall metrics contain more than one non-stratified row.
        // NOTE(review): this snippet is cut off after the declaration of 'cols'; the remainder of
        // the method is not visible here.
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary <string, IDataView> metrics)
        {
            IDataView top;

            if (!metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out top))
            {
                throw Host.Except("Did not find the top-k results data view");
            }
            var sb = new StringBuilder();

            using (var cursor = top.GetRowCursor(col => true))
            {
                // 'index' is reused for each of the three column lookups below.
                int index;
                if (!top.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.Instance, out index))
                {
                    throw Host.Except("Data view does not contain the 'Instance' column");
                }
                var instanceGetter = cursor.GetGetter <ReadOnlyMemory <char> >(index);
                if (!top.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.AnomalyScore, out index))
                {
                    throw Host.Except("Data view does not contain the 'Anomaly Score' column");
                }
                var scoreGetter = cursor.GetGetter <Single>(index);
                if (!top.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.Label, out index))
                {
                    throw Host.Except("Data view does not contain the 'Label' column");
                }
                var labelGetter = cursor.GetGetter <Single>(index);

                bool hasRows = false;
                while (cursor.MoveNext())
                {
                    // The table header is written lazily, so nothing is emitted for an empty view.
                    if (!hasRows)
                    {
                        sb.AppendFormat("{0} Top-scored Results", _topScored);
                        sb.AppendLine();
                        sb.AppendLine("=================================================");
                        sb.AppendLine("Instance    Anomaly Score     Labeled");
                        hasRows = true;
                    }
                    var    name  = default(ReadOnlyMemory <char>);
                    Single score = 0;
                    Single label = 0;
                    instanceGetter(ref name);
                    scoreGetter(ref score);
                    labelGetter(ref label);
                    sb.AppendFormat("{0,-10}{1,12:G4}{2,12}", name, score, label);
                    sb.AppendLine();
                }
            }
            // Only log when at least one row was written.
            if (sb.Length > 0)
            {
                ch.Info(MessageSensitivity.UserData, sb.ToString());
            }

            IDataView overall;

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out overall))
            {
                throw Host.Except("No overall metrics found");
            }

            // Find the number of anomalies, and the thresholds.
            int numAnomIndex;

            if (!overall.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies, out numAnomIndex))
            {
                throw Host.Except("Could not find the 'NumAnomalies' column");
            }

            int  stratCol;
            var  hasStrat = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol);
            int  stratVal;
            bool hasStratVals = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);

            // The two stratification columns are expected to be present (or absent) together.
            Contracts.Assert(hasStrat == hasStratVals);
            long numAnomalies = 0;

            // Only the anomaly-count column and, when present, the stratification column are activated.
            using (var cursor = overall.GetRowCursor(col => col == numAnomIndex ||
                                                     (hasStrat && col == stratCol)))
            {
                var numAnomGetter = cursor.GetGetter <long>(numAnomIndex);
                // Stays null when there is no stratification column; rows then keep strat == 0.
                ValueGetter <uint> stratGetter = null;
                if (hasStrat)
                {
                    var type = cursor.Schema.GetColumnType(stratCol);
                    stratGetter = RowCursorUtils.GetGetterAs <uint>(type, cursor, stratCol);
                }
                bool foundRow = false;
                while (cursor.MoveNext())
                {
                    uint strat = 0;
                    if (stratGetter != null)
                    {
                        stratGetter(ref strat);
                    }
                    // Rows with a non-zero stratification value are skipped.
                    if (strat > 0)
                    {
                        continue;
                    }
                    // At most one non-stratified row is allowed.
                    if (foundRow)
                    {
                        throw Host.Except("Found multiple non-stratified rows in overall results data view");
                    }
                    foundRow = true;
                    numAnomGetter(ref numAnomalies);
                }
            }

            // Column display names; the format strings and _k/_p are defined outside this snippet.
            var kFormatName = string.Format(FoldDrAtKFormat, _k);
            var pFormatName = string.Format(FoldDrAtPFormat, _p);
            var numAnomName = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies);

            (string Source, string Name)[] cols =
コード例 #17
0
        /// <summary>
        /// Prints the results of one fold: a table of the top-scored instances taken from the
        /// top-k results data view, followed by the per-fold summary of a selected subset of the
        /// overall metrics columns.
        /// </summary>
        /// <param name="ch">Channel that receives the formatted output.</param>
        /// <param name="metrics">Metric data views keyed by kind; both the top-k results view and
        /// the overall metrics view are required, otherwise an exception is thrown.</param>
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            IDataView topK;
            bool haveTopK = metrics.TryGetValue(AnomalyDetectionEvaluator.TopKResults, out topK);
            if (!haveTopK)
            {
                throw Host.Except("Did not find the top-k results data view");
            }

            var report = new StringBuilder();
            using (var rows = topK.GetRowCursor(col => true))
            {
                // Resolve the three columns in turn, reusing a single index variable.
                int colIndex;

                bool found = topK.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.Instance, out colIndex);
                if (!found)
                {
                    throw Host.Except("Data view does not contain the 'Instance' column");
                }
                ValueGetter<DvText> getInstance = rows.GetGetter<DvText>(colIndex);

                found = topK.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.AnomalyScore, out colIndex);
                if (!found)
                {
                    throw Host.Except("Data view does not contain the 'Anomaly Score' column");
                }
                ValueGetter<float> getScore = rows.GetGetter<float>(colIndex);

                found = topK.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.TopKResultsColumns.Label, out colIndex);
                if (!found)
                {
                    throw Host.Except("Data view does not contain the 'Label' column");
                }
                ValueGetter<float> getLabel = rows.GetGetter<float>(colIndex);

                while (rows.MoveNext())
                {
                    // Nothing has been written before the first row, so an empty builder
                    // means the table header is still pending.
                    if (report.Length == 0)
                    {
                        report.AppendFormat("{0} Top-scored Results", _topScored).AppendLine();
                        report.AppendLine("=================================================");
                        report.AppendLine("Instance    Anomaly Score     Labeled");
                    }

                    DvText instance = default(DvText);
                    float anomalyScore = 0;
                    float labelValue = 0;
                    getInstance(ref instance);
                    getScore(ref anomalyScore);
                    getLabel(ref labelValue);
                    report.AppendFormat("{0,-10}{1,12:G4}{2,12}", instance, anomalyScore, labelValue).AppendLine();
                }
            }

            if (report.Length > 0)
            {
                ch.Info(MessageSensitivity.UserData, report.ToString());
            }

            IDataView overallView;
            bool haveOverall = metrics.TryGetValue(MetricKinds.OverallMetrics, out overallView);
            if (!haveOverall)
            {
                throw Host.Except("No overall metrics found");
            }

            // Find the number of anomalies, and the thresholds.
            int anomCountCol;
            bool haveAnomCount = overallView.Schema.TryGetColumnIndex(AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies, out anomCountCol);
            if (!haveAnomCount)
            {
                throw Host.Except("Could not find the 'NumAnomalies' column");
            }

            int stratCol;
            int stratValCol;
            bool hasStrat = overallView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol);
            bool hasStratVals = overallView.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratValCol);
            Contracts.Assert(hasStrat == hasStratVals);

            DvInt8 numAnomalies = 0;
            using (var rows = overallView.GetRowCursor(col => (hasStrat && col == stratCol) || col == anomCountCol))
            {
                ValueGetter<DvInt8> getAnomCount = rows.GetGetter<DvInt8>(anomCountCol);

                // Left null when the view has no stratification column.
                ValueGetter<uint> getStrat = null;
                if (hasStrat)
                {
                    var stratType = rows.Schema.GetColumnType(stratCol);
                    getStrat = RowCursorUtils.GetGetterAs<uint>(stratType, rows, stratCol);
                }

                bool seenOverallRow = false;
                while (rows.MoveNext())
                {
                    uint stratId = 0;
                    if (getStrat != null)
                    {
                        getStrat(ref stratId);
                    }

                    // Only the row with stratification value zero carries the count we want.
                    if (stratId == 0)
                    {
                        if (seenOverallRow)
                        {
                            throw Host.Except("Found multiple non-stratified rows in overall results data view");
                        }
                        seenOverallRow = true;
                        getAnomCount(ref numAnomalies);
                    }
                }
            }

            // Select (and, for the detection-rate columns, rename) the metrics to report.
            var chooseArgs = new ChooseColumnsTransform.Arguments();
            chooseArgs.Column = new[]
            {
                new ChooseColumnsTransform.Column
                {
                    Name   = string.Format(FoldDrAtKFormat, _k),
                    Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
                },
                new ChooseColumnsTransform.Column
                {
                    Name   = string.Format(FoldDrAtPFormat, _p),
                    Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
                },
                new ChooseColumnsTransform.Column
                {
                    Name   = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies),
                    Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
                },
                new ChooseColumnsTransform.Column { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK },
                new ChooseColumnsTransform.Column { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP },
                new ChooseColumnsTransform.Column { Name = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos },
                new ChooseColumnsTransform.Column { Name = BinaryClassifierEvaluator.Auc }
            };

            IDataView selected = new ChooseColumnsTransform(Host, chooseArgs, overallView);
            string weightedSummary;
            ch.Info(MetricWriter.GetPerFoldResults(Host, selected, out weightedSummary));
        }