Example #1
0
        /// <summary>
        /// Builds the value getter for output column <paramref name="iinfo"/> by delegating to
        /// <c>RowCursorUtils.GetLabelGetter</c> on that column's source column in <paramref name="input"/>.
        /// </summary>
        /// <param name="ch">Channel; may be null. It is only asserted here, not otherwise used.</param>
        /// <param name="input">Row that supplies the source column.</param>
        /// <param name="iinfo">Index into <c>Infos</c> selecting which column to build a getter for.</param>
        /// <param name="disposer">Always set to null; nothing needs cleaning up.</param>
        /// <returns>The getter returned by <c>RowCursorUtils.GetLabelGetter</c>.</returns>
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Contracts.AssertValueOrNull(ch);
            Contracts.AssertValue(input);
            Contracts.Assert(0 <= iinfo && iinfo < Infos.Length);

            disposer = null;

            int srcCol = Infos[iinfo].Source;
            var srcType = input.Schema.GetColumnType(srcCol);

            // Sanity check that the source type is one GetLabelGetter accepts;
            // TestGetLabelGetter is expected to return null on success.
            Contracts.Assert(RowCursorUtils.TestGetLabelGetter(srcType) == null);

            Delegate getter = RowCursorUtils.GetLabelGetter(input, srcCol);
            return getter;
        }
Example #2
0
            /// <summary>
            /// Binds the row getters used by this aggregator. Label and score getters are always bound;
            /// a weight getter is bound only when the schema declares a weight column.
            /// Only the first pass is expected here (asserts <c>PassNum &lt; 1</c>).
            /// </summary>
            /// <param name="row">Row whose getters are captured.</param>
            /// <param name="schema">Role-mapped schema; must provide a label column.</param>
            public override void InitializeNextPass(IRow row, RoleMappedSchema schema)
            {
                Contracts.Assert(PassNum < 1);
                Contracts.AssertValue(schema.Label);

                var scoreCol = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);

                _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
                _scoreGetter = row.GetGetter<TScore>(scoreCol.Index);
                Contracts.AssertValue(_labelGetter);
                Contracts.AssertValue(_scoreGetter);

                // Weight is optional; leave _weightGetter untouched when there is no weight column.
                var weightCol = schema.Weight;
                if (weightCol != null)
                {
                    _weightGetter = row.GetGetter<float>(weightCol.Index);
                }
            }
            /// <summary>
            /// Binds the row getters needed for the upcoming pass.
            /// Every pass binds the score getter, plus the feature getter when DBI is being computed.
            /// Pass 0 additionally binds the label getter (a NaN-producing getter when there is no label column)
            /// and, if present, the weight getter. Any later pass must be pass 1 with DBI enabled; it calls
            /// <c>InitializeSecondPass</c> on the unweighted and (when present) weighted counters.
            /// </summary>
            /// <param name="row">Row whose getters are captured.</param>
            /// <param name="schema">Role-mapped schema describing feature/score/label/weight columns.</param>
            public override void InitializeNextPass(IRow row, RoleMappedSchema schema)
            {
                AssertValid(assertGetters: false);

                Host.AssertValue(row);
                Host.AssertValue(schema);

                // The feature column is only read when the DBI metric is requested.
                if (_calculateDbi)
                {
                    Host.AssertValue(schema.Feature);
                    _featGetter = row.GetGetter<VBuffer<Single>>(schema.Feature.Index);
                }

                var scoreInfo = schema.GetUniqueColumn(MetadataUtils.Const.ScoreValueKind.Score);
                Host.Assert(scoreInfo.Type.VectorSize == _scoresArr.Length);
                _scoreGetter = row.GetGetter<VBuffer<Single>>(scoreInfo.Index);

                if (PassNum != 0)
                {
                    // A second pass only happens when DBI is being computed.
                    Host.Assert(PassNum == 1 && _calculateDbi);
                    UnweightedCounters.InitializeSecondPass(_clusterCentroids);

                    var weighted = WeightedCounters;
                    if (weighted != null)
                    {
                        weighted.InitializeSecondPass(_clusterCentroids);
                    }
                }
                else
                {
                    if (schema.Label == null)
                    {
                        // No label column: every row reports NaN as its label.
                        _labelGetter = (ref Single value) => value = Single.NaN;
                    }
                    else
                    {
                        _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Index);
                    }

                    if (schema.Weight != null)
                    {
                        _weightGetter = row.GetGetter<Single>(schema.Weight.Index);
                    }
                }

                AssertValid(assertGetters: true);
            }
        /// <summary>
        /// Creates getters for the per-slot L1 and L2 loss output columns.
        /// Both getters share one cache, keyed on <c>input.Position</c>, that holds |label - score[i]| for every
        /// score slot. Label and score are therefore read at most once per row position.
        /// </summary>
        /// <param name="input">Row that supplies the label (<c>LabelIndex</c>) and score (<c>ScoreIndex</c>) columns.</param>
        /// <param name="activeCols">Predicate telling which output columns need a getter.</param>
        /// <param name="disposer">Always set to null.</param>
        /// <returns>A two-element array indexed by <c>L1Col</c>/<c>L2Col</c>; entries for inactive columns stay null.</returns>
        public override Delegate[] CreateGetters(IRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);

            disposer = null;

            bool needL1 = activeCols(L1Col);
            bool needL2 = activeCols(L2Col);
            bool needInputs = needL1 || needL2;

            long cachedPosition = -1;
            Float label = 0;
            var score = default(VBuffer<Float>);
            var l1 = VBufferUtils.CreateDense<Double>(_scoreSize);

            ValueGetter<Float> labelGetter;
            ValueGetter<VBuffer<Float>> scoreGetter;
            if (needInputs)
            {
                labelGetter = RowCursorUtils.GetLabelGetter(input, LabelIndex);
                scoreGetter = input.GetGetter<VBuffer<Float>>(ScoreIndex);
            }
            else
            {
                // Neither loss column was requested: bind inert getters rather than touching the input columns.
                labelGetter = (ref Float value) => value = Single.NaN;
                scoreGetter = (ref VBuffer<Float> dst) => dst = default(VBuffer<Float>);
            }

            // Recomputes the shared L1 buffer only when the cursor has moved to a different row.
            Action updateCacheIfNeeded = () =>
            {
                if (cachedPosition == input.Position)
                {
                    return;
                }

                labelGetter(ref label);
                scoreGetter(ref score);
                Double lab = label;
                foreach (var pair in score.Items(all: true))
                {
                    l1.Values[pair.Key] = Math.Abs(lab - pair.Value);
                }
                cachedPosition = input.Position;
            };

            var getters = new Delegate[2];

            if (needL1)
            {
                ValueGetter<VBuffer<Double>> l1Fn = (ref VBuffer<Double> dst) =>
                {
                    updateCacheIfNeeded();
                    l1.CopyTo(ref dst);
                };
                getters[L1Col] = l1Fn;
            }

            if (needL2)
            {
                // Writes the square of each cached L1 value into the destination slot.
                VBufferUtils.PairManipulator<Double, Double> square =
                    (int slot, Double x, ref Double y) => y = x * x;

                ValueGetter<VBuffer<Double>> l2Fn = (ref VBuffer<Double> dst) =>
                {
                    updateCacheIfNeeded();
                    // Reset dst to length _scoreSize with zero explicit entries (reusing its arrays),
                    // then combine l1 into it via ApplyWith and the squaring manipulator.
                    dst = new VBuffer<Double>(_scoreSize, 0, dst.Values, dst.Indices);
                    VBufferUtils.ApplyWith(ref l1, ref dst, square);
                };
                getters[L2Col] = l2Fn;
            }

            return getters;
        }