/// <summary>
/// Creates the getter delegate for the binding info at index <paramref name="iinfo"/>.
/// The returned delegate resets a shared <c>NgramBufferBuilder</c>, reads every source
/// vector of that info from <paramref name="input"/>, adds its n-grams to the builder,
/// and writes the builder's result into the destination buffer.
/// </summary>
/// <param name="ch">Channel used to assert that every source type is a vector of keys.</param>
/// <param name="input">Row from which the per-source vector getters are obtained.</param>
/// <param name="iinfo">Index into <c>_bindings.Infos</c>, <c>_bindings.Types</c> and <c>_exes</c>.</param>
/// <param name="decorator">
/// Optional. When non-null it is called with <paramref name="iinfo"/> and the finder returned by
/// <c>GetNgramIdFinder</c>, and the value it returns is used as the finder instead.
/// </param>
/// <returns>A <c>ValueGetter&lt;VBuffer&lt;Float&gt;&gt;</c>, typed as <see cref="Delegate"/>.</returns>
private Delegate MakeGetter(IChannel ch, IRow input, int iinfo, FinderDecorator decorator = null)
{
    var info = _bindings.Infos[iinfo];
    ch.Assert(info.SrcTypes.All(t => t.IsVector && t.ItemType.IsKey));

    // One getter per source index, each reading its vector as uint values.
    int sourceCount = info.SrcIndices.Length;
    var sourceGetters = new ValueGetter<VBuffer<uint>>[sourceCount];
    int slot = 0;
    while (slot < sourceCount)
    {
        sourceGetters[slot] = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, info.SrcIndices[slot]);
        slot++;
    }

    // Scratch buffer shared by all sources and all invocations of the returned delegate.
    var keys = default(VBuffer<uint>);

    var finder = GetNgramIdFinder(iinfo);
    if (decorator != null)
    {
        finder = decorator(iinfo, finder);
    }

    var exe = _exes[iinfo];
    var builder = new NgramBufferBuilder(
        exe.NgramLength,
        exe.SkipLength,
        _bindings.Types[iinfo].ValueCount,
        finder);

    // Per-source key-count limit; a KeyCount that is not positive is replaced by uint.MaxValue.
    var keyLimits = info.SrcTypes
        .Select(t =>
        {
            var count = t.ItemType.KeyCount;
            return count > 0 ? (uint)count : uint.MaxValue;
        })
        .ToArray();

    // REVIEW: Special casing the srcCount==1 case could potentially improve perf.
    ValueGetter<VBuffer<Float>> getter = (ref VBuffer<Float> dst) =>
    {
        builder.Reset();
        for (int s = 0; s < sourceCount; s++)
        {
            sourceGetters[s](ref keys);
            builder.AddNgrams(ref keys, s, keyLimits[s]);
        }

        builder.GetResult(ref dst);
    };

    return getter;
}
/// <summary>
/// Creates the getter delegate for the column info at index <paramref name="iinfo"/>.
/// Each invocation of the returned delegate first reads the source key vector. If the
/// n-gram builder reports <c>IsEmpty</c>, the destination becomes a zero-length buffer
/// that reuses the destination's existing <c>Values</c> and <c>Indices</c>. Otherwise the
/// builder is reset, fed the source vector, and its result is written to the destination,
/// followed by a weighting step selected by <c>_exes[iinfo].Weighting</c>:
/// <list type="bullet">
/// <item><description><c>Tf</c>: no further processing.</description></item>
/// <item><description><c>TfIdf</c>: <c>VBufferUtils.Apply</c> with <c>v = v * _invDocFreqs[iinfo][i]</c>.</description></item>
/// <item><description><c>Idf</c>: <c>VBufferUtils.Apply</c> with <c>v = v >= 1 ? _invDocFreqs[iinfo][i] : 0</c>.</description></item>
/// </list>
/// </summary>
/// <param name="ch">Channel; may be null (only checked with <c>AssertValueOrNull</c>).</param>
/// <param name="input">Row from which the source vector getter is obtained.</param>
/// <param name="iinfo">Index into <c>Infos</c>; asserted to be within range.</param>
/// <param name="disposer">Always set to null.</param>
/// <returns>A <c>ValueGetter&lt;VBuffer&lt;Float&gt;&gt;</c>, typed as <see cref="Delegate"/>.</returns>
/// <remarks>Throws the exception produced by <c>Host.Except</c> for any other weighting value.</remarks>
protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
{
    Host.AssertValueOrNull(ch);
    Host.AssertValue(input);
    Host.Assert(0 <= iinfo && iinfo < Infos.Length);
    Host.Assert(Infos[iinfo].TypeSrc.IsVector);
    Host.Assert(Infos[iinfo].TypeSrc.ItemType.IsKey);

    disposer = null;

    var fetchKeys = RowCursorUtils.GetVecGetterAs<uint>(NumberType.U4, input, Infos[iinfo].Source);

    // Scratch buffer reused across invocations of the returned delegate.
    var keys = default(VBuffer<uint>);

    var builder = new NgramBufferBuilder(
        _exes[iinfo].NgramLength,
        _exes[iinfo].SkipLength,
        _ngramMaps[iinfo].Count,
        GetNgramIdFinder(iinfo));

    // A key count of zero is replaced by uint.MaxValue.
    uint declaredCount = (uint)Infos[iinfo].TypeSrc.ItemType.KeyCount;
    uint keyLimit = declaredCount == 0 ? uint.MaxValue : declaredCount;

    ValueGetter<VBuffer<Float>> getter;
    switch (_exes[iinfo].Weighting)
    {
        case WeightingCriteria.TfIdf:
            Host.AssertValue(_invDocFreqs[iinfo]);
            getter = (ref VBuffer<Float> dst) =>
            {
                // The source is read before the emptiness check, as in every branch.
                fetchKeys(ref keys);
                if (bldrIsEmpty(builder))
                {
                    dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    return;
                }

                builder.Reset();
                builder.AddNgrams(ref keys, 0, keyLimit);
                builder.GetResult(ref dst);
                VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = (Float)(v * _invDocFreqs[iinfo][i]));
            };
            break;

        case WeightingCriteria.Idf:
            Host.AssertValue(_invDocFreqs[iinfo]);
            getter = (ref VBuffer<Float> dst) =>
            {
                fetchKeys(ref keys);
                if (bldrIsEmpty(builder))
                {
                    dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    return;
                }

                builder.Reset();
                builder.AddNgrams(ref keys, 0, keyLimit);
                builder.GetResult(ref dst);
                VBufferUtils.Apply(ref dst, (int i, ref Float v) => v = v >= 1 ? (Float)_invDocFreqs[iinfo][i] : 0);
            };
            break;

        case WeightingCriteria.Tf:
            getter = (ref VBuffer<Float> dst) =>
            {
                fetchKeys(ref keys);
                if (bldrIsEmpty(builder))
                {
                    dst = new VBuffer<Float>(0, dst.Values, dst.Indices);
                    return;
                }

                builder.Reset();
                builder.AddNgrams(ref keys, 0, keyLimit);
                builder.GetResult(ref dst);
            };
            break;

        default:
            throw Host.Except("Unsupported weighting criteria");
    }

    return getter;
}

// Reads the builder's IsEmpty flag; kept as a tiny helper so the three getters share one guard.
private static bool bldrIsEmpty(NgramBufferBuilder builder)
{
    return builder.IsEmpty;
}