/// <summary>
/// Makes sure <paramref name="score"/> and <paramref name="prob"/> hold the mapper's results for the
/// current row of <paramref name="input"/>, recomputing them only when the row position has changed.
/// </summary>
/// <param name="mapper">Maps the feature vector to a score and a probability.</param>
/// <param name="cachedPosition">Position of the row the cached results belong to; updated after recomputation.</param>
/// <param name="featureGetter">Fills <paramref name="features"/>; when null, the current contents of <paramref name="features"/> are used as-is.</param>
/// <param name="features">Buffer holding the feature vector passed to <paramref name="mapper"/>.</param>
/// <param name="score">Receives the mapped score.</param>
/// <param name="prob">Receives the mapped probability.</param>
/// <param name="input">Row whose position decides whether the cache is stale.</param>
private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<float>, float, float> mapper,
    ref long cachedPosition, ValueGetter<VBuffer<float>> featureGetter, ref VBuffer<float> features,
    ref float score, ref float prob, DataViewRow input)
{
    Contracts.AssertValue(mapper);

    // Cached results already correspond to this row: nothing to recompute.
    if (cachedPosition == input.Position)
    {
        return;
    }

    featureGetter?.Invoke(ref features);
    mapper(in features, ref score, ref prob);
    cachedPosition = input.Position;
}
コード例 #2
0
            /// <summary>
            /// Wraps <paramref name="input"/> in a row that exposes <c>OutputSchema</c> through a single score getter.
            /// </summary>
            /// <param name="input">The row the score getter reads from.</param>
            /// <param name="predicate">Forwarded unchanged to <c>CreateScoreGetter</c>.</param>
            /// <returns>A row whose only getter is the score getter; the getter's disposer is handed to the row.</returns>
            public DataViewRow GetRow(DataViewRow input, Func<int, bool> predicate)
            {
                Delegate getter = CreateScoreGetter(input, predicate, out var disposer);
                return new SimpleRow(OutputSchema, input, new[] { getter }, disposer);
            }
コード例 #3
0
 /// <summary>
 /// Returns a getter for the feature column identified by <paramref name="data"/>, reading from
 /// <paramref name="row"/> and assuming the column is a vector of <see cref="float"/>.
 /// </summary>
 /// <param name="row">The row to read feature values from.</param>
 /// <param name="data">Role-mapped data whose schema designates the feature column; validated with <c>Contracts.CheckValue</c>.</param>
 /// <returns>A getter that fills a <c>VBuffer&lt;float&gt;</c> with the row's feature values.</returns>
 public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this DataViewRow row, RoleMappedData data)
 {
     Contracts.CheckValue(data, nameof(data));
     var schema = data.Schema;
     return GetFeatureFloatVectorGetter(row, schema);
 }
コード例 #4
0
 /// <summary>
 /// Return a new state object.
 /// </summary>
 /// <param name="input">The input row the new state object is created for.</param>
 /// <returns>The newly created <typeparamref name="TState"/> instance.</returns>
 protected abstract TState InitializeState(DataViewRow input);
コード例 #5
0
 /// <summary>
 /// Creates an aggregator that reads <typeparamref name="TValue"/> values from column
 /// <paramref name="col"/> of <paramref name="row"/>.
 /// </summary>
 /// <param name="row">The source row; asserted non-null.</param>
 /// <param name="col">Index of the source column in <paramref name="row"/>'s schema.</param>
 public ListAggregator(DataViewRow row, int col)
 {
     Contracts.AssertValue(row);
     var sourceColumn = row.Schema[col];
     _srcGetter = row.GetGetter<TValue>(sourceColumn);
     // Expose the instance method Getter as the aggregated vector getter.
     _getter = (ValueGetter<VBuffer<TValue>>)Getter;
 }
コード例 #6
0
            /// <summary>
            /// Builds an action that transfers the value of column <paramref name="index"/> of <paramref name="input"/>
            /// into a <typeparamref name="TRow"/> instance for the given <paramref name="column"/> definition.
            /// </summary>
            /// <remarks>
            /// Text columns bound to <see cref="string"/> fields (scalar or array) use a converting setter. Other supported
            /// shapes (VBuffer&lt;T&gt; to T[], VBuffer&lt;T&gt; to VBuffer&lt;T&gt;, and primitive T to T) are dispatched through
            /// reflection to a generic setter factory closed over the item type. A field typed <c>T?</c> (or <c>T?[]</c>)
            /// is expected to bind to a column of raw type <c>T</c>.
            /// NOTE(review): presumably <paramref name="poke"/> writes the field and <paramref name="peek"/> reads it back — confirm.
            /// </remarks>
            /// <exception>Thrown via <c>Ch.ExceptNotImpl</c> when the column type is neither a vector nor a primitive type.</exception>
            private Action<TRow> GenerateSetter(DataViewRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
            {
                var columnType = input.Schema[index].Type;
                var outputType = column.OutputType;
                var typeArgument = outputType;
                Func<DataViewRow, int, Delegate, Delegate, Action<TRow>> factory;

                if (outputType.IsArray)
                {
                    Ch.Assert(columnType is VectorType);
                    var elementType = outputType.GetElementType();

                    // VBuffer<ReadOnlyMemory<char>> -> String[]
                    if (elementType == typeof(string))
                    {
                        Ch.Assert(columnType.GetItemType() is TextDataViewType);
                        return CreateConvertingVBufferSetter<ReadOnlyMemory<char>, string>(input, index, poke, peek, x => x.ToString());
                    }

                    // VBuffer<T> -> T[]; a T?[] field is expected to match a column whose item raw type is T.
                    bool elementIsNullable = elementType.IsGenericType && elementType.GetGenericTypeDefinition() == typeof(Nullable<>);
                    var expectedItemRawType = elementIsNullable ? Nullable.GetUnderlyingType(elementType) : elementType;
                    Ch.Assert(columnType.GetItemType().RawType == expectedItemRawType);

                    factory = CreateDirectVBufferSetter<int>;
                    typeArgument = elementType;
                }
                else if (columnType is VectorType vectorType)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> VBuffer<ReadOnlyMemory<char>>?
                    Ch.Assert(outputType.IsGenericType);
                    Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
                    Ch.Assert(outputType.GetGenericArguments()[0] == vectorType.ItemType.RawType);

                    factory = CreateVBufferToVBufferSetter<int>;
                    typeArgument = vectorType.ItemType.RawType;
                }
                else if (columnType is PrimitiveDataViewType)
                {
                    if (outputType == typeof(string))
                    {
                        // ReadOnlyMemory<char> -> String
                        Ch.Assert(columnType is TextDataViewType);
                        Ch.Assert(peek == null);
                        return CreateConvertingActionSetter<ReadOnlyMemory<char>, string>(input, index, poke, x => x.ToString());
                    }

                    // T -> T; a T? field is expected to match a column whose raw type is T.
                    bool fieldIsNullable = outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>);
                    var expectedRawType = fieldIsNullable ? Nullable.GetUnderlyingType(outputType) : outputType;
                    Ch.Assert(columnType.RawType == expectedRawType);

                    factory = CreateDirectSetter<int>;
                }
                else
                {
                    // REVIEW: Is this even possible?
                    throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName);
                }

                // The <int> instantiation above only selects the factory; re-close it over the real item type.
                MethodInfo closedFactory = factory.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(typeArgument);
                return (Action<TRow>)closedFactory.Invoke(this, new object[] { input, index, poke, peek });
            }
コード例 #7
0
 /// <summary>
 /// Get the getter for the first input column.
 /// </summary>
 /// <param name="row">The row to read the first input column from.</param>
 /// <returns>A getter producing <typeparamref name="TLabel"/> values from <paramref name="row"/>.</returns>
 protected abstract ValueGetter <TLabel> GetLabelGetter(DataViewRow row);
コード例 #8
0
 /// <summary>
 /// Creates a row that wraps <paramref name="input"/> and keeps a reference to <paramref name="mapper"/>.
 /// </summary>
 public RowImpl(DataViewRow input, Mapper mapper)
     : base(input)
     => _mapper = mapper;
コード例 #9
0
 /// <summary>
 /// Returns a row over <paramref name="input"/> backed by this instance's mapper.
 /// </summary>
 /// <param name="input">The row to wrap.</param>
 /// <param name="active">Not consulted by this implementation.</param>
 public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)
     => new RowImpl(input, _mapper);
コード例 #10
0
                /// <summary>
                /// Maps <paramref name="input"/> through the inner mapper and wraps the resulting row
                /// so that it reports this mapper's <c>OutputSchema</c>.
                /// </summary>
                /// <param name="input">The row to map.</param>
                /// <param name="predicate">Forwarded unchanged to the inner mapper.</param>
                public DataViewRow GetRow(DataViewRow input, Func<int, bool> predicate)
                    => new RowImpl(_mapper.GetRow(input, predicate), OutputSchema);
コード例 #11
0
 /// <summary>
 /// Returns a <typeparamref name="T"/> getter for the source column that <c>_bindings.SrcCols</c>
 /// associates with output <paramref name="iinfo"/>.
 /// </summary>
 private ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
 {
     var sourceIndex = _bindings.SrcCols[iinfo];
     return input.GetGetter<T>(input.Schema[sourceIndex]);
 }
コード例 #12
0
 /// <summary>
 /// Initializes the wrapper around <paramref name="input"/>, which is stored in <c>Input</c>.
 /// </summary>
 /// <param name="input">The wrapped row; asserted non-null.</param>
 private protected WrappingRow(DataViewRow input)
 {
     Contracts.AssertValue(input);

     Input = input;
 }
コード例 #13
0
 /// <summary>
 /// Builds a row over <paramref name="input"/> whose getters come from <c>CreateGetters</c>.
 /// </summary>
 /// <param name="input">The source row; validated with <c>_ectx.CheckValue</c>.</param>
 /// <param name="predicate">Column-activity predicate; validated with <c>_ectx.CheckValue</c>.</param>
 public DataViewRow GetRow(DataViewRow input, Func<int, bool> predicate)
 {
     _ectx.CheckValue(input, nameof(input));
     _ectx.CheckValue(predicate, nameof(predicate));

     var schema = OutputSchema;
     var getters = CreateGetters(input, predicate);
     return new SimpleRow(schema, input, getters);
 }
コード例 #14
0
 /// <summary>
 /// Not used by this class; always throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always.</exception>
 protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
 {
     throw new NotImplementedException("This should never be called!");
 }
コード例 #15
0
 /// <summary>
 /// Creates a fresh <see cref="RowCursorState"/> configured with <c>_truncationLevel</c>;
 /// <paramref name="input"/> is not used.
 /// </summary>
 protected override RowCursorState InitializeState(DataViewRow input)
     => new RowCursorState(_truncationLevel);
コード例 #16
0
        /// <summary>
        /// Creates the getters for the cluster-id, sorted-cluster and sorted-cluster-score output columns.
        /// </summary>
        /// <remarks>
        /// All active getters share one cache keyed on <c>input.Position</c>. When the row changes, the score
        /// vector is read once, copied into a dense array, and the cluster indices are ordered by ascending
        /// score (ties keep index order). Cluster ids are reported one-based.
        /// </remarks>
        /// <param name="input">Row providing the score vector at column <c>ScoreIndex</c>.</param>
        /// <param name="activeCols">Predicate selecting which output columns need a getter.</param>
        /// <param name="disposer">Always set to <c>null</c>.</param>
        /// <returns>A three-element array indexed by output column; entries for inactive columns are <c>null</c>.</returns>
        private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
        {
            disposer = null;

            var getters = new Delegate[3];

            if (!activeCols(ClusterIdCol) && !activeCols(SortedClusterCol) && !activeCols(SortedClusterScoreCol))
            {
                return getters;
            }

            long cachedPosition = -1;
            VBuffer<Single> scores = default(VBuffer<Single>);
            var scoresArr = new Single[_numClusters];

            int[] sortedIndices = new int[_numClusters];

            // Ascending by score. Breaking ties on the index makes the in-place (unstable) Array.Sort yield
            // exactly the order of a stable sort, and float.CompareTo orders NaN like Comparer<float>.Default.
            // Built once so that refreshing the cache for each new row allocates nothing (unlike a LINQ OrderBy).
            Comparison<int> byAscendingScore = (a, b) =>
            {
                int cmp = scoresArr[a].CompareTo(scoresArr[b]);
                return cmp != 0 ? cmp : a.CompareTo(b);
            };

            var scoreGetter = input.GetGetter<VBuffer<Single>>(input.Schema[ScoreIndex]);
            Action updateCacheIfNeeded =
                () =>
                {
                    if (cachedPosition != input.Position)
                    {
                        scoreGetter(ref scores);
                        scores.CopyTo(scoresArr);
                        for (int i = 0; i < sortedIndices.Length; i++)
                        {
                            sortedIndices[i] = i;
                        }
                        Array.Sort(sortedIndices, byAscendingScore);
                        cachedPosition = input.Position;
                    }
                };

            if (activeCols(ClusterIdCol))
            {
                // One-based id of the lowest-scoring cluster.
                ValueGetter<uint> assignedFn =
                    (ref uint dst) =>
                    {
                        updateCacheIfNeeded();
                        dst = (uint)sortedIndices[0] + 1;
                    };
                getters[ClusterIdCol] = assignedFn;
            }

            if (activeCols(SortedClusterScoreCol))
            {
                // All scores, ordered from lowest to highest.
                ValueGetter<VBuffer<Single>> topKScoresFn =
                    (ref VBuffer<Single> dst) =>
                    {
                        updateCacheIfNeeded();
                        var editor = VBufferEditor.Create(ref dst, _numClusters);
                        for (int i = 0; i < _numClusters; i++)
                        {
                            editor.Values[i] = scores.GetItemOrDefault(sortedIndices[i]);
                        }
                        dst = editor.Commit();
                    };
                getters[SortedClusterScoreCol] = topKScoresFn;
            }

            if (activeCols(SortedClusterCol))
            {
                // One-based cluster ids, ordered from lowest to highest score.
                ValueGetter<VBuffer<uint>> topKClassesFn =
                    (ref VBuffer<uint> dst) =>
                    {
                        updateCacheIfNeeded();
                        var editor = VBufferEditor.Create(ref dst, _numClusters);
                        for (int i = 0; i < _numClusters; i++)
                        {
                            editor.Values[i] = (uint)sortedIndices[i] + 1;
                        }
                        dst = editor.Commit();
                    };
                getters[SortedClusterCol] = topKClassesFn;
            }
            return getters;
        }
コード例 #17
0
 /// <summary>
 /// Wraps <paramref name="input"/> so it can be read as <typeparamref name="TRow"/> instances.
 /// </summary>
 public IRowReadableAs<TRow> GetRow(DataViewRow input)
 {
     var typedRow = new TypedRow(this, input);
     return new RowImplementation(typedRow);
 }
コード例 #18
0
 /// <summary>
 /// Creates a getter that reads column <paramref name="colIndex"/> of <paramref name="input"/> as
 /// <typeparamref name="T"/>, associated with the name <paramref name="colName"/>.
 /// </summary>
 public NameOnnxValueGetter(DataViewRow input, string colName, int colIndex)
 {
     _srcgetter = input.GetGetter<T>(colIndex);
     _colName = colName;
 }
コード例 #19
0
 /// <summary>
 /// Creates a typed row over <paramref name="input"/> owned by <paramref name="parent"/>,
 /// passing the label <c>"Row"</c> to the base constructor.
 /// </summary>
 public TypedRow(TypedCursorable<TRow> parent, DataViewRow input)
     : base(parent, input, "Row") { }
コード例 #20
0
            /// <summary>
            /// Wraps <paramref name="input"/> in a row exposing <c>OutputSchema</c> through a single score getter.
            /// <paramref name="activeColumns"/> is not consulted.
            /// </summary>
            DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
            {
                Delegate getter = CreateScoreGetter(input, out var disposer);
                return new SimpleRow(OutputSchema, input, new[] { getter }, disposer);
            }
コード例 #21
0
 /// <summary>
 /// Get the getter for the second input column.
 /// </summary>
 /// <param name="row">The row to read the second input column from.</param>
 /// <returns>A getter producing <typeparamref name="TScore"/> values from <paramref name="row"/>.</returns>
 protected abstract ValueGetter <TScore> GetScoreGetter(DataViewRow row);
コード例 #22
0
 /// <summary>
 /// Creates the score getter over <paramref name="input"/>.
 /// </summary>
 /// <param name="input">The row the getter reads from.</param>
 /// <param name="disposer">An action returned alongside the getter; callers hand it to the row they build.
 /// NOTE(review): presumably releases resources held by the getter and may be null — confirm.</param>
 /// <returns>The score getter delegate.</returns>
 internal abstract Delegate CreateScoreGetter(DataViewRow input, out Action disposer);
コード例 #23
0
 /// <summary>
 /// Returns a getter delegate for the mapped value computed from <paramref name="input"/>.
 /// </summary>
 /// <param name="input">The row the mapping reads from.</param>
 /// <returns>A getter delegate. NOTE(review): its concrete delegate type is defined by implementations — confirm per override.</returns>
 public abstract Delegate GetMappingGetter(DataViewRow input);
コード例 #24
0
 /// <summary>
 /// Generic factory invoked through reflection by
 /// <see cref="CreateNamedOnnxValueGetterVec(DataViewRow, Type, int, OnnxShape)"/>; creates a vector getter
 /// for column <paramref name="colIndex"/> of <paramref name="input"/> with shape <paramref name="onnxShape"/>.
 /// </summary>
 private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataViewRow input, int colIndex, OnnxShape onnxShape)
     => new NamedOnnxValueGetterVec<T>(input, colIndex, onnxShape);
コード例 #25
0
                /// <summary>
                /// Maps <paramref name="input"/> through the inner mapper (restricted to <paramref name="activeColumns"/>)
                /// and wraps the result so it reports this mapper's <c>OutputSchema</c>.
                /// </summary>
                DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
                    => new RowImpl(_mapper.GetRow(input, activeColumns), OutputSchema);
コード例 #26
0
 /// <summary>
 /// Creates a getter that reads column <paramref name="colIndex"/> of <paramref name="input"/> as
 /// <typeparamref name="T"/>, named after that column.
 /// </summary>
 public NameOnnxValueGetter(DataViewRow input, int colIndex)
 {
     var sourceColumn = input.Schema[colIndex];
     _colName = sourceColumn.Name;
     _srcGetter = input.GetGetter<T>(sourceColumn);
 }
コード例 #27
0
 /// <summary>
 /// Creates the score getter over <paramref name="input"/>.
 /// </summary>
 /// <param name="input">The row the getter reads from.</param>
 /// <param name="mapperPredicate">Column predicate forwarded by callers. NOTE(review): presumably selects the active mapper columns — confirm.</param>
 /// <param name="disposer">An action returned alongside the getter; callers hand it to the row they build.
 /// NOTE(review): presumably releases resources held by the getter and may be null — confirm.</param>
 /// <returns>The score getter delegate.</returns>
 internal abstract Delegate CreateScoreGetter(DataViewRow input, Func <int, bool> mapperPredicate, out Action disposer);
コード例 #28
0
 /// <summary>
 /// Returns a <see cref="Single"/> getter for the score column at <c>_bindings.ScoreIndex</c>.
 /// </summary>
 protected override ValueGetter<Single> GetScoreGetter(DataViewRow row)
 {
     var scoreColumn = row.Schema[_bindings.ScoreIndex];
     return row.GetGetter<Single>(scoreColumn);
 }
コード例 #29
0
 /// <summary>
 /// Returns a getter that reads the label column designated by <paramref name="data"/> as a <see cref="float"/>.
 /// The label column type is assumed to have been validated already for the kind of training being done.
 /// </summary>
 /// <param name="row">The row to read the label from.</param>
 /// <param name="data">Role-mapped data whose schema designates the label column; validated with <c>Contracts.CheckValue</c>.</param>
 public static ValueGetter<float> GetLabelFloatGetter(this DataViewRow row, RoleMappedData data)
 {
     Contracts.CheckValue(data, nameof(data));
     var schema = data.Schema;
     return GetLabelFloatGetter(row, schema);
 }
コード例 #30
0
        private ValueGetter <VBuffer <TDst> > GetGetterVec <T0, TDst>(IExceptionContext ectx, DataViewRow input, DataViewSchema.Column[] inputColumns, int[] perm, Delegate del, DataViewType outputColumnItemType)
        {
            ectx.Assert(inputColumns.Length == 1);
            ectx.Assert(perm.Length == 1);
            ectx.Assert(perm[0] == 0);

            var fn      = (Func <T0, TDst>)del;
            var getSrc0 = input.GetGetter <VBuffer <T0> >(inputColumns[0]);
            var src0    = default(VBuffer <T0>);

            var dstDef = fn(default(T0));
            var isDef  = Conversions.Instance.GetIsDefaultPredicate <TDst>(outputColumnItemType);

            if (isDef(in dstDef))
            {
                // Sparsity is preserved.
                return
                    ((ref VBuffer <TDst> dst) =>
                {
                    getSrc0(ref src0);
                    int count = src0.GetValues().Length;

                    var editor = VBufferEditor.Create(ref dst, src0.Length, count);
                    for (int i = 0; i < count; i++)
                    {
                        editor.Values[i] = fn(src0.GetValues()[i]);
                    }

                    if (!src0.IsDense)
                    {
                        src0.GetIndices().CopyTo(editor.Indices);
                    }
                    dst = editor.Commit();
                });
            }