private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator decorator = null)
{
    ch.Assert(_bindings.Infos[iinfo].SrcTypes.All(t => t.IsVector && t.ItemType.IsKey));

    var info = _bindings.Infos[iinfo];
    int srcCount = info.SrcIndices.Length;
    ValueGetter<VBuffer<uint>>[] getSrc = new ValueGetter<VBuffer<uint>>[srcCount];
    for (int isrc = 0; isrc < srcCount; isrc++)
    {
        getSrc[isrc] = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, info.SrcIndices[isrc]);
    }
    var src = default(VBuffer<uint>);
    var ngramIdFinder = GetNgramIdFinder(iinfo);
    if (decorator != null)
    {
        ngramIdFinder = decorator(iinfo, ngramIdFinder);
    }
    var bldr = new NgramBufferBuilder(_exes[iinfo].NgramLength, _exes[iinfo].SkipLength,
        _bindings.Types[iinfo].ValueCount, ngramIdFinder);
    var keyCounts = _bindings.Infos[iinfo].SrcTypes.Select(
        t => t.ItemType.KeyCount > 0 ? (uint)t.ItemType.KeyCount : uint.MaxValue).ToArray();

    // REVIEW: Special casing the srcCount == 1 case could potentially improve perf.
    ValueGetter<VBuffer<Float>> del =
        (ref VBuffer<Float> dst) =>
        {
            bldr.Reset();
            for (int i = 0; i < srcCount; i++)
            {
                getSrc[i](ref src);
                bldr.AddNgrams(ref src, i, keyCounts[i]);
            }
            bldr.GetResult(ref dst);
        };
    return del;
}
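// Minimal usage sketch (assumed call shape, for illustration only; the real call sites live in
// the transform's row-cursor plumbing). The delegate built above is a ValueGetter<VBuffer<Float>>
// that, per row, pulls each source key-vector, feeds it through the shared NgramBufferBuilder,
// and emits a single concatenated bag-of-ngrams vector. A hypothetical caller holding a row
// cursor over the input might do:
//
//     var getter = (ValueGetter<VBuffer<Float>>)MakeGetter(ch, cursor, iinfo);
//     var features = default(VBuffer<Float>);
//     while (cursor.MoveNext())
//         getter(ref features); // features has _bindings.Types[iinfo].ValueCount slots
//
// Note that source columns whose key type has unknown cardinality (KeyCount == 0) get a key
// count of uint.MaxValue above, so no key id of theirs is ever filtered out.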
protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
{
    Host.AssertValueOrNull(ch);
    Host.AssertValue(input);
    Host.Assert(0 <= iinfo && iinfo < Infos.Length);
    Host.Assert(Infos[iinfo].TypeSrc.IsVector);
    Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);

    disposer = null;

    var getSrc = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, Infos[iinfo].Source);
    var src = default(VBuffer<uint>);
    var bldr = new NgramBufferBuilder(_exes[iinfo].NgramLength, _exes[iinfo].SkipLength,
        _ngramMaps[iinfo].Count, GetNgramIdFinder(iinfo));
    var keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
    if (keyCount == 0)
    {
        keyCount = uint.MaxValue;
    }

    ValueGetter<VBuffer<Float>> del;
    switch (_exes[iinfo].Weighting)
    {
        case WeightingCriteria.TfIdf:
            Host.AssertValue(_invDocFreqs[iinfo]);
            del =
                (ref VBuffer<Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(ref src, 0, keyCount);
                        bldr.GetResult(ref dst);
                        VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = (Float)(v * _invDocFreqs[iinfo][i]));
                    }
                    else
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    }
                };
            break;
        case WeightingCriteria.Idf:
            Host.AssertValue(_invDocFreqs[iinfo]);
            del =
                (ref VBuffer<Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(ref src, 0, keyCount);
                        bldr.GetResult(ref dst);
                        VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = v >= 1 ? (Float)_invDocFreqs[iinfo][i] : 0);
                    }
                    else
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    }
                };
            break;
        case WeightingCriteria.Tf:
            del =
                (ref VBuffer<Float> dst) =>
                {
                    getSrc(ref src);
                    if (!bldr.IsEmpty)
                    {
                        bldr.Reset();
                        bldr.AddNgrams(ref src, 0, keyCount);
                        bldr.GetResult(ref dst);
                    }
                    else
                    {
                        dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    }
                };
            break;
        default:
            throw Host.Except("Unsupported weighting criteria");
    }
    return del;
}
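// Illustrative sketch (not part of the transform): the three weighting criteria above reduce to
// the following per-slot arithmetic, where tf is the raw n-gram count in the current row and
// invDocFreq is the precomputed log(totalDocs / docFreq) from Train below. This helper and its
// name are hypothetical, written out only to make the switch above easy to read.
private static Float ApplyWeightingSketch(WeightingCriteria weighting, Float tf, double invDocFreq)
{
    switch (weighting)
    {
        case WeightingCriteria.Tf:
            return tf;                              // raw term frequency
        case WeightingCriteria.Idf:
            return tf >= 1 ? (Float)invDocFreq : 0; // presence-gated inverse document frequency
        case WeightingCriteria.TfIdf:
            return (Float)(tf * invDocFreq);        // classic tf-idf product
        default:
            throw new InvalidOperationException("Unsupported weighting criteria");
    }
}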
private SequencePool[] Train(Arguments args, IDataView trainingData, out double[][] invDocFreqs)
{
    // lims[iinfo] contains the maximum number of n-grams to store in the dictionary, for each
    // n-gram length, from 1 (in position 0) up to ngramLength (in position ngramLength - 1).
    var lims = new int[Infos.Length][];
    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
    {
        var all = args.Column[iinfo].AllLengths ?? args.AllLengths;
        var ngramLength = _exes[iinfo].NgramLength;
        var maxNumTerms = Utils.Size(args.Column[iinfo].MaxNumTerms) > 0 ? args.Column[iinfo].MaxNumTerms : args.MaxNumTerms;
        if (!all)
        {
            Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
                Utils.Size(maxNumTerms) == 1 && maxNumTerms[0] > 0, nameof(args.MaxNumTerms));
            lims[iinfo] = new int[ngramLength];
            lims[iinfo][ngramLength - 1] = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[0];
        }
        else
        {
            Host.CheckUserArg(Utils.Size(maxNumTerms) <= ngramLength, nameof(args.MaxNumTerms));
            Host.CheckUserArg(Utils.Size(maxNumTerms) == 0 ||
                maxNumTerms.All(i => i >= 0) && maxNumTerms[maxNumTerms.Length - 1] > 0, nameof(args.MaxNumTerms));
            var extend = Utils.Size(maxNumTerms) == 0 ? Arguments.DefaultMaxTerms : maxNumTerms[maxNumTerms.Length - 1];
            lims[iinfo] = Utils.BuildArray(ngramLength,
                i => i < Utils.Size(maxNumTerms) ? maxNumTerms[i] : extend);
        }
    }

    var helpers = new NgramBufferBuilder[Infos.Length];
    var getters = new ValueGetter<VBuffer<uint>>[Infos.Length];
    var src = new VBuffer<uint>[Infos.Length];

    // Keep track of how many n-grams are in the pool for each value of n. Position i in
    // counts[iinfo] is the number of (i + 1)-grams currently in the pool for column iinfo.
    var counts = new int[Infos.Length][];
    var ngramMaps = new SequencePool[Infos.Length];
    bool[] activeInput = new bool[trainingData.Schema.ColumnCount];
    foreach (var info in Infos)
    {
        activeInput[info.Source] = true;
    }
    using (var cursor = trainingData.GetRowCursor(col => activeInput[col]))
    using (var pch = Host.StartProgressChannel("Building n-gram dictionary"))
    {
        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            Host.Assert(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsKey);
            var ngramLength = _exes[iinfo].NgramLength;
            var skipLength = _exes[iinfo].SkipLength;

            getters[iinfo] = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, cursor, Infos[iinfo].Source);
            src[iinfo] = default(VBuffer<uint>);
            counts[iinfo] = new int[ngramLength];
            ngramMaps[iinfo] = new SequencePool();

            // Note: GetNgramIdFinderAdd will control how many n-grams of a specific length will
            // be added (using lims[iinfo]), therefore we set slotLim to the maximum.
            helpers[iinfo] = new NgramBufferBuilder(ngramLength, skipLength, Utils.ArrayMaxSize,
                GetNgramIdFinderAdd(counts[iinfo], lims[iinfo], ngramMaps[iinfo], _exes[iinfo].RequireIdf(), Host));
        }

        int cInfoFull = 0;
        bool[] infoFull = new bool[Infos.Length];

        invDocFreqs = new double[Infos.Length][];

        long totalDocs = 0;
        Double rowCount = trainingData.GetRowCount(true) ?? Double.NaN;
        var buffers = new VBuffer<float>[Infos.Length];
        pch.SetHeader(new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }),
            e => e.SetProgress(0, totalDocs, rowCount));
        while (cInfoFull < Infos.Length && cursor.MoveNext())
        {
            totalDocs++;
            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                getters[iinfo](ref src[iinfo]);
                var keyCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
                if (keyCount == 0)
                {
                    keyCount = uint.MaxValue;
                }
                if (!infoFull[iinfo])
                {
                    if (_exes[iinfo].RequireIdf())
                    {
                        helpers[iinfo].Reset();
                    }

                    helpers[iinfo].AddNgrams(ref src[iinfo], 0, keyCount);
                    if (_exes[iinfo].RequireIdf())
                    {
                        int totalNgrams = counts[iinfo].Sum();
                        Utils.EnsureSize(ref invDocFreqs[iinfo], totalNgrams);
                        helpers[iinfo].GetResult(ref buffers[iinfo]);
                        foreach (var pair in buffers[iinfo].Items())
                        {
                            if (pair.Value >= 1)
                            {
                                invDocFreqs[iinfo][pair.Key] += 1;
                            }
                        }
                    }
                }
                AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);
            }
        }

        pch.Checkpoint(counts.Sum(c => c.Sum()), totalDocs);
        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            for (int i = 0; i < Utils.Size(invDocFreqs[iinfo]); i++)
            {
                if (invDocFreqs[iinfo][i] != 0)
                {
                    invDocFreqs[iinfo][i] = Math.Log(totalDocs / invDocFreqs[iinfo][i]);
                }
            }
        }

        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            AssertValid(counts[iinfo], lims[iinfo], ngramMaps[iinfo]);

            int ngramLength = _exes[iinfo].NgramLength;
            for (int i = 0; i < ngramLength; i++)
            {
                _exes[iinfo].NonEmptyLevels[i] = counts[iinfo][i] > 0;
            }
        }

        return ngramMaps;
    }
}
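// Worked example of the IDF pass above, with hypothetical numbers: suppose totalDocs = 1000 and a
// particular n-gram slot was seen in 10 distinct documents. The cursoring loop leaves
// invDocFreqs[iinfo][slot] == 10 (one increment per document in which pair.Value >= 1), and the
// post-pass then rewrites it in place:
//
//     invDocFreqs[iinfo][slot] = Math.Log(1000 / 10.0); // ≈ 4.605
//
// An n-gram present in every document therefore ends up with weight Math.Log(1.0) == 0, so it
// contributes nothing under the Idf and TfIdf criteria in GetGetterCore.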