/// <summary>
/// Builds the table of getters for the scorer's columns. Slot 0 holds the predicted-label getter;
/// the remaining slots map onto the columns of <paramref name="output"/>, shifted by the number of
/// derived columns. Slots rejected by <paramref name="predicate"/> are left null.
/// </summary>
/// <param name="output">The mapper's output row; its schema is asserted to be the row mapper's output schema.</param>
/// <param name="predicate">Tells, per slot index, whether a getter is wanted.</param>
/// <returns>An array with one entry per binding; unwanted entries are null.</returns>
protected override Delegate[] GetGetters(DataViewRow output, Func<int, bool> predicate)
        {
            Host.Assert(Bindings.DerivedColumnCount == 1);
            Host.AssertValue(output);
            Host.AssertValue(predicate);
            Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema);
            Host.Assert(Bindings.InfoCount == output.Schema.Count + 1);

            var result = new Delegate[Bindings.InfoCount];
            int offset = Bindings.DerivedColumnCount;

            // The predicted label is derived from the score, so its getter also hands back a
            // score getter that is reused below for the score slot.
            Delegate sharedScoreGetter = null;
            if (predicate(0))
            {
                Host.Assert(output.IsColumnActive(output.Schema[Bindings.ScoreColumnIndex]));
                result[0] = GetPredictedLabelGetter(output, out sharedScoreGetter);
            }

            for (int slot = offset; slot < result.Length; slot++)
            {
                if (predicate(slot))
                {
                    bool reuseScoreGetter = slot == offset + Bindings.ScoreColumnIndex && sharedScoreGetter != null;
                    result[slot] = reuseScoreGetter
                        ? sharedScoreGetter
                        : GetGetterFromRow(output, slot - offset);
                }
            }

            return result;
        }
예제 #2
0
 /// <summary>
 /// Wraps a single row in a data view containing just that row. This helps when a row is all you
 /// have but the facility you want to use only accepts data views (for instance, saving a
 /// <see cref="DataViewRow"/> through a <see cref="Microsoft.ML.Data.IO.BinarySaver"/>).
 /// The method cannot guarantee that <paramref name="row"/> stays unchanged, so callers must be
 /// careful with the row, and with the source it came from, for as long as the returned view may be in use.
 /// </summary>
 /// <param name="env">The environment used to create the host of the resulting data view</param>
 /// <param name="row">The row to wrap; every one of its columns must be active</param>
 /// <returns>A data view with a single row backed by <paramref name="row"/></returns>
 public static IDataView RowAsDataView(IHostEnvironment env, DataViewRow row)
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(row, nameof(row));

     // Stop at the first inactive column; one is enough to reject the row.
     bool everyColumnActive = true;
     int columnCount = row.Schema.Count;
     for (int index = 0; index < columnCount; index++)
     {
         if (!row.IsColumnActive(index))
         {
             everyColumnActive = false;
             break;
         }
     }
     env.CheckParam(everyColumnActive, nameof(row), "Some columns were inactive");

     return new OneRowDataView(env, row);
 }
 /// <summary>
 /// Fetches the strongly typed getter for column <paramref name="col"/> of <paramref name="output"/>.
 /// The index range and the column being active are asserted, not checked.
 /// </summary>
 protected static ValueGetter<T> GetGetterFromRow<T>(DataViewRow output, int col)
 {
     Contracts.AssertValue(output);
     Contracts.Assert(col >= 0 && col < output.Schema.Count);

     var column = output.Schema[col];
     Contracts.Assert(output.IsColumnActive(column));
     return output.GetGetter<T>(column);
 }
        /// <summary>
        /// Runs <paramref name="input"/> through <paramref name="mapper"/> and collects every
        /// <see cref="StatefulRow"/> produced on the way into <paramref name="rows"/>. A
        /// <see cref="CompositeRowToRowMapper"/> with inner mappers is unrolled: each inner mapper is
        /// processed recursively, feeding the previous mapper's row into the next one.
        /// </summary>
        /// <param name="input">The row handed to the mapper.</param>
        /// <param name="mapper">The mapper to apply.</param>
        /// <param name="activeColumns">The columns requested as active.</param>
        /// <param name="rows">Receives the stateful rows that were created.</param>
        /// <returns>The row produced by the mapper, or by the last inner mapper of a composite.</returns>
        internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable<DataViewSchema.Column> activeColumns, List<StatefulRow> rows)
        {
            Contracts.CheckValue(input, nameof(input));
            // A null mapper could never work (it is dereferenced below); fail with a clear argument error instead.
            Contracts.CheckValue(mapper, nameof(mapper));
            Contracts.CheckValue(activeColumns, nameof(activeColumns));

            var innerMappers = (mapper as CompositeRowToRowMapper)?.InnerMappers;

            if (innerMappers == null || innerMappers.Length == 0)
            {
                // The index set is only needed for this validation, so it is built on this path only.
                var activeIndices = new HashSet<int>(activeColumns.Select(col => col.Index));
                for (int c = 0; c < input.Schema.Count; ++c)
                {
                    var column = input.Schema[c];
                    if (activeIndices.Contains(c) && !input.IsColumnActive(column))
                    {
                        throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{column.Name}' active but it was not.");
                    }
                }

                var row = mapper.GetRow(input, activeColumns);
                if (row is StatefulRow statefulRow)
                {
                    rows.Add(statefulRow);
                }
                return row;
            }

            // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
            // what we need from them. The last one is asked for the requested columns directly; the others are
            // asked for whatever the next mapper in the chain depends on.
            var deps = new IEnumerable<DataViewSchema.Column>[innerMappers.Length];

            deps[deps.Length - 1] = activeColumns;
            for (int i = deps.Length - 1; i >= 1; --i)
            {
                deps[i - 1] = innerMappers[i].GetDependencies(deps[i]);
            }

            DataViewRow result = input;

            for (int i = 0; i < innerMappers.Length; ++i)
            {
                result = GetStatefulRows(result, innerMappers[i], deps[i], rows);

                // NOTE(review): the recursive call has already added its returned row when that row is a
                // StatefulRow, so this adds the same instance a second time. Kept as-is because consumers
                // of `rows` are not visible here; confirm whether the duplicate is intended.
                if (result is StatefulRow statefulResult)
                {
                    rows.Add(statefulResult);
                }
            }
            return result;
        }
예제 #5
0
        /// <summary>
        /// Runs <paramref name="input"/> through <paramref name="mapper"/> and collects every
        /// <see cref="StatefulRow"/> produced on the way into <paramref name="rows"/>. A
        /// <see cref="CompositeRowToRowMapper"/> with inner mappers is unrolled: each inner mapper is
        /// processed recursively, feeding the previous mapper's row into the next one.
        /// </summary>
        /// <param name="input">The row handed to the mapper.</param>
        /// <param name="mapper">The mapper to apply.</param>
        /// <param name="active">Tells, per column index, whether the column is requested as active.</param>
        /// <param name="rows">Receives the stateful rows that were created.</param>
        /// <returns>The row produced by the mapper, or by the last inner mapper of a composite.</returns>
        internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, Func<int, bool> active, List<StatefulRow> rows)
        {
            Contracts.CheckValue(input, nameof(input));
            // A null mapper could never work (it is dereferenced below); fail with a clear argument error instead.
            Contracts.CheckValue(mapper, nameof(mapper));
            Contracts.CheckValue(active, nameof(active));

            var innerMappers = (mapper as CompositeRowToRowMapper)?.InnerMappers;

            if (innerMappers == null || innerMappers.Length == 0)
            {
                for (int c = 0; c < input.Schema.Count; ++c)
                {
                    if (active(c) && !input.IsColumnActive(c))
                    {
                        throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema[c].Name}' active but it was not.");
                    }
                }

                var row = mapper.GetRow(input, active);
                if (row is StatefulRow statefulRow)
                {
                    rows.Add(statefulRow);
                }
                return row;
            }

            // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
            // what we need from them. The last one is asked for the requested columns directly; the others are
            // asked for whatever the next mapper in the chain depends on.
            var deps = new Func<int, bool>[innerMappers.Length];

            deps[deps.Length - 1] = active;
            for (int i = deps.Length - 1; i >= 1; --i)
            {
                // Bind the predicate to a local so the lazy query does not depend on the mutable loop variable.
                var wanted = deps[i];
                var wantedOutputs = innerMappers[i].OutputSchema.Where(c => wanted(c.Index));

                // A set gives constant-time membership tests; an empty set simply answers false.
                var requiredIndices = new HashSet<int>(innerMappers[i].GetDependencies(wantedOutputs).Select(col => col.Index));
                deps[i - 1] = requiredIndices.Contains;
            }

            DataViewRow result = input;

            for (int i = 0; i < innerMappers.Length; ++i)
            {
                result = GetStatefulRows(result, innerMappers[i], deps[i], rows);

                // NOTE(review): the recursive call has already added its returned row when that row is a
                // StatefulRow, so this adds the same instance a second time. Kept as-is because consumers
                // of `rows` are not visible here; confirm whether the duplicate is intended.
                if (result is StatefulRow statefulResult)
                {
                    rows.Add(statefulResult);
                }
            }
            return result;
        }
예제 #6
0
        /// <summary>
        /// Produces the <see cref="ValueGetter{T}"/> for an active column of a row, typed only as a
        /// <see cref="Delegate"/>. The delegate's type parameter matches the raw type of the column.
        /// </summary>
        /// <param name="row">The row whose getter is requested</param>
        /// <param name="col">Index of the column; it must be active on <paramref name="row"/></param>
        /// <returns>The column's getter, as a delegate</returns>
        public static Delegate GetGetterAsDelegate(DataViewRow row, int col)
        {
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckParam(col >= 0 && col < row.Schema.Count, nameof(col));

            var column = row.Schema[col];
            Contracts.CheckParam(row.IsColumnActive(column), nameof(col), "column was not active");

            var rawType = column.Type.RawType;
            return Utils.MarshalInvoke(_getGetterAsDelegateCoreMethodInfo, rawType, row, col);
        }
예제 #7
0
        /// <summary>
        /// Returns the typed getter of the source column that feeds the <paramref name="iinfo"/>-th
        /// entry of <c>Infos</c>. The source column is asserted to be active on <paramref name="input"/>.
        /// </summary>
        protected ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
        {
            Host.AssertValue(input);
            Host.Assert(iinfo >= 0 && iinfo < Infos.Length);

            var sourceColumn = input.Schema[Infos[iinfo].Source];
            Host.Assert(input.IsColumnActive(sourceColumn));
            return input.GetGetter<T>(sourceColumn);
        }
예제 #8
0
            // Always reports true. (As the original author joked, the shuffling is even uniformly IID.)
            public bool CanShuffle
            {
                get { return true; }
            }

            /// <summary>
            /// Creates the view over <paramref name="row"/>, registering a host named "OneRowDataView"
            /// with <paramref name="env"/>. The arguments are only asserted, including that every
            /// column of the row is active.
            /// </summary>
            public OneRowDataView(IHostEnvironment env, DataViewRow row)
            {
                Contracts.AssertValue(env);
                _host = env.Register("OneRowDataView");

                _host.AssertValue(row);
                _host.Assert(Enumerable.All(Enumerable.Range(0, row.Schema.Count), index => row.IsColumnActive(index)));
                _row = row;
            }
예제 #9
0
            /// <summary>
            /// Returns the typed getter of the source column recorded in <c>_cols</c> for the
            /// <paramref name="iinfo"/>-th column pair. The column is asserted to be active on
            /// <paramref name="input"/>.
            /// </summary>
            private ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
            {
                Host.AssertValue(input);
                Host.Assert(iinfo >= 0 && iinfo < _parent.ColumnPairs.Length);

                int sourceIndex = _cols[iinfo];
                var sourceColumn = input.Schema[sourceIndex];
                Host.Assert(input.IsColumnActive(sourceColumn));

                return input.GetGetter<T>(sourceColumn);
            }
        /// <summary>
        /// Returns the getter for column <paramref name="col"/> of <paramref name="row"/> as an untyped
        /// delegate, dispatching through <c>Utils.MarshalInvoke</c> on the raw type of the column.
        /// The index range and the column being active are asserted, not checked.
        /// </summary>
        protected static Delegate GetGetterFromRow(DataViewRow row, int col)
        {
            Contracts.AssertValue(row);
            Contracts.Assert(col >= 0 && col < row.Schema.Count);
            Contracts.Assert(row.IsColumnActive(row.Schema[col]));

            var rawType = row.Schema[col].Type.RawType;
            return Utils.MarshalInvoke(_getGetterFromRowMethodInfo, rawType, row, col);
        }
예제 #11
0
        /// <summary>
        /// Given a row and a column index, returns a getter that uses
        /// <see cref="Conversions.GetStringConversion{TSrc}(DataViewType)"/> to write the column's
        /// values, whatever their type, into a string builder. This only succeeds when a string
        /// conversion exists for the column's type. It is handy when a value should be emitted as
        /// text in a generic way and the exact formatting does not matter.
        /// </summary>
        /// <param name="row">The row to read from</param>
        /// <param name="col">Index of the column; it must be active and of a primitive type</param>
        /// <returns>A getter that fills a <see cref="StringBuilder"/> with the column's value</returns>
        public static ValueGetter<StringBuilder> GetGetterAsStringBuilder(DataViewRow row, int col)
        {
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckParam(col >= 0 && col < row.Schema.Count, nameof(col));
            Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active");

            var sourceType = row.Schema[col].Type;
            bool isPrimitive = sourceType is PrimitiveDataViewType;
            Contracts.Check(isPrimitive, "Source column type must be primitive");

            return Utils.MarshalInvoke(GetGetterAsStringBuilderCore<int>, sourceType.RawType, sourceType, row, col);
        }
예제 #12
0
            /// <summary>
            /// Tells whether <paramref name="column"/> is active in this row. Columns that the bindings
            /// map to the input are answered by the input row; all others are active exactly when a
            /// getter was created for them.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                int mappedIndex = _parent._bindings.MapColumnIndex(out bool fromInput, column.Index);

                return fromInput
                    ? _input.IsColumnActive(_input.Schema[mappedIndex])
                    : _getters[mappedIndex] != null;
            }
예제 #13
0
            /// <summary>
            /// Tells whether the column at index <paramref name="col"/> is active in this row. Columns
            /// that the bindings map to the input are answered by the input row; all others are active
            /// exactly when a getter was created for them.
            /// </summary>
            public override bool IsColumnActive(int col)
            {
                int mappedIndex = _parent._bindings.MapColumnIndex(out bool fromInput, col);

                return fromInput
                    ? _input.IsColumnActive(mappedIndex)
                    : _getters[mappedIndex] != null;
            }
            /// <summary>
            /// Tells whether <paramref name="column"/> is active in this row. The column index must be
            /// below the schema's column count. Columns that the bindings map to the input are answered
            /// by the input row; the rest are answered by the appended row.
            /// </summary>
            public override bool IsColumnActive(DataViewSchema.Column column)
            {
                Contracts.CheckParam(column.Index < Schema.Count, nameof(column));

                int mappedIndex = _parent._bindings.MapColumnIndex(out bool fromInput, column.Index);
                var owner = fromInput ? _input : _appendedRow;
                return owner.IsColumnActive(owner.Schema[mappedIndex]);
            }
        /// <summary>
        /// Returns the getter for column <paramref name="col"/> of <paramref name="row"/> as an untyped
        /// delegate. The generic <c>GetGetterFromRow&lt;T&gt;</c> overload is closed over the raw type
        /// of the column through reflection and then invoked.
        /// </summary>
        protected static Delegate GetGetterFromRow(DataViewRow row, int col)
        {
            Contracts.AssertValue(row);
            Contracts.Assert(col >= 0 && col < row.Schema.Count);
            Contracts.Assert(row.IsColumnActive(col));

            var rawType = row.Schema[col].Type.RawType;

            // The <int> instantiation only serves to locate the generic method definition.
            Func<DataViewRow, int, ValueGetter<int>> template = GetGetterFromRow<int>;
            var typedMethod = template.GetMethodInfo()
                .GetGenericMethodDefinition()
                .MakeGenericMethod(rawType);

            object[] arguments = { row, col };
            return (Delegate)typedMethod.Invoke(null, arguments);
        }
            /// <summary>
            /// Tells whether the column at index <paramref name="col"/> is active in this row. The index
            /// must lie within the schema. Columns that the bindings map to the input are answered by
            /// the input row; the rest are answered by the appended row.
            /// </summary>
            public override bool IsColumnActive(int col)
            {
                Contracts.CheckParam(col >= 0 && col < Schema.Count, nameof(col));

                int mappedIndex = _parent._bindings.MapColumnIndex(out bool fromInput, col);
                return fromInput
                    ? _input.IsColumnActive(mappedIndex)
                    : _appendedRow.IsColumnActive(mappedIndex);
            }
예제 #17
0
        /// <summary>
        /// Given the destination item type, a row and a column index, returns a getter for the
        /// vector-valued column that converts to a vector of <paramref name="typeDst"/> if needed.
        /// This is the weakly typed counterpart of
        /// <see cref="GetVecGetterAs{TDst}(PrimitiveDataViewType, DataViewRow, int)"/>.
        /// </summary>
        /// <param name="typeDst">The item type of the resulting vector</param>
        /// <param name="row">The row to read from</param>
        /// <param name="col">Index of the column; it must be active and of a vector type</param>
        /// <returns>The converting getter, as a delegate</returns>
        public static Delegate GetVecGetterAs(PrimitiveDataViewType typeDst, DataViewRow row, int col)
        {
            Contracts.CheckValue(typeDst, nameof(typeDst));
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckParam(col >= 0 && col < row.Schema.Count, nameof(col));
            Contracts.CheckParam(row.IsColumnActive(col), nameof(col), "column was not active");

            var sourceVectorType = row.Schema[col].Type as VectorType;
            Contracts.Check(sourceVectorType != null, "Source column type must be vector");

            // The <int, int> instantiation only serves to locate the generic method definition.
            Func<VectorType, PrimitiveDataViewType, GetterFactory, ValueGetter<VBuffer<int>>> template = GetVecGetterAsCore<int, int>;
            var typedMethod = template.GetMethodInfo()
                .GetGenericMethodDefinition()
                .MakeGenericMethod(sourceVectorType.ItemType.RawType, typeDst.RawType);

            var arguments = new object[] { sourceVectorType, typeDst, GetterFactory.Create(row, col) };
            return (Delegate)typedMethod.Invoke(null, arguments);
        }
예제 #18
0
        /// <summary>
        /// Given a destination type, a row and a column index, returns a getter for the column that
        /// converts to <paramref name="typeDst"/> if needed. This is the weakly typed counterpart of
        /// <see cref="GetGetterAs{TDst}"/>.
        /// </summary>
        /// <param name="typeDst">The type to convert to; it must be primitive</param>
        /// <param name="row">The row to read from</param>
        /// <param name="col">Index of the column; it must be active and of a primitive type</param>
        /// <returns>The converting getter, as a delegate</returns>
        /// <seealso cref="GetGetterAs{TDst}"/>
        public static Delegate GetGetterAs(DataViewType typeDst, DataViewRow row, int col)
        {
            Contracts.CheckValue(typeDst, nameof(typeDst));
            Contracts.CheckParam(typeDst is PrimitiveDataViewType, nameof(typeDst));
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckParam(col >= 0 && col < row.Schema.Count, nameof(col));

            var column = row.Schema[col];
            Contracts.CheckParam(row.IsColumnActive(column), nameof(col), "column was not active");

            var sourceType = column.Type;
            Contracts.Check(sourceType is PrimitiveDataViewType, "Source column type must be primitive");

            // The <int, int> instantiation only serves to locate the generic method definition.
            Func<DataViewType, DataViewType, DataViewRow, int, ValueGetter<int>> template = GetGetterAsCore<int, int>;
            var typedMethod = template.GetMethodInfo()
                .GetGenericMethodDefinition()
                .MakeGenericMethod(sourceType.RawType, typeDst.RawType);

            var arguments = new object[] { sourceType, typeDst, row, col };
            return (Delegate)typedMethod.Invoke(null, arguments);
        }
예제 #19
0
        /// <summary>
        /// Maps <paramref name="input"/> through the chain of inner mappers and returns the row produced
        /// by the last one. With no inner mappers the input row itself is returned, after verifying that
        /// every requested column is active on it.
        /// </summary>
        /// <param name="input">The row to map; its schema must be <c>InputSchema</c>.</param>
        /// <param name="active">Tells, per column index, whether the column is requested as active.</param>
        /// <returns>The output row of the last inner mapper, or <paramref name="input"/> when there are none.</returns>
        public DataViewRow GetRow(DataViewRow input, Func<int, bool> active)
        {
            Contracts.CheckValue(input, nameof(input));
            Contracts.CheckValue(active, nameof(active));
            Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema");

            if (InnerMappers.Length == 0)
            {
                for (int c = 0; c < input.Schema.Count; ++c)
                {
                    if (active(c) && !input.IsColumnActive(c))
                    {
                        throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema[c].Name}' active but it was not.");
                    }
                }
                return input;
            }

            // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
            // what we need from them. The last one is asked for the requested columns directly; the others are
            // asked for whatever the next mapper in the chain depends on.
            var deps = new Func<int, bool>[InnerMappers.Length];

            deps[deps.Length - 1] = active;
            for (int i = deps.Length - 1; i >= 1; --i)
            {
                // Bind the predicate to a local so the lazy query does not depend on the mutable loop variable.
                var wanted = deps[i];
                var outputColumns = InnerMappers[i].OutputSchema.Where(c => wanted(c.Index));

                // A set gives constant-time membership tests; an empty set simply answers false.
                var requiredIndices = new System.Collections.Generic.HashSet<int>(
                    InnerMappers[i].GetDependencies(outputColumns).Select(col => col.Index));
                deps[i - 1] = requiredIndices.Contains;
            }

            DataViewRow result = input;

            for (int i = 0; i < InnerMappers.Length; ++i)
            {
                result = InnerMappers[i].GetRow(result, deps[i]);
            }

            return result;
        }
            /// <summary>
            /// Creates the getter of the probability column: it reads the float score at
            /// <c>_scoreColIndex</c> from <paramref name="input"/> and passes it through
            /// <c>_calibrator.PredictProbability</c>. This implementation does not use
            /// <paramref name="iinfo"/> or <paramref name="activeOutput"/>, and needs no disposer.
            /// </summary>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                Host.AssertValue(input);
                disposer = null;

                Host.Assert(input.IsColumnActive(_scoreColIndex));
                var scoreGetter = input.GetGetter<float>(_scoreColIndex);

                // Scratch value captured by the closure and overwritten on every call.
                float rawScore = default(float);

                ValueGetter<float> probabilityGetter = (ref float dst) =>
                {
                    scoreGetter(ref rawScore);
                    dst = _calibrator.PredictProbability(rawScore);
                };
                return probabilityGetter;
            }
        /// <summary>
        /// Maps <paramref name="input"/> through the chain of inner mappers and returns the row produced
        /// by the last one. With no inner mappers the input row itself is returned, after verifying that
        /// every requested column is active on it.
        /// </summary>
        /// <param name="input">The row to map; its schema must be <c>InputSchema</c>.</param>
        /// <param name="activeColumns">The columns requested as active.</param>
        /// <returns>The output row of the last inner mapper, or <paramref name="input"/> when there are none.</returns>
        DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
        {
            Contracts.CheckValue(input, nameof(input));
            Contracts.CheckValue(activeColumns, nameof(activeColumns));
            Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema");

            // A set keeps the per-column membership tests below constant-time.
            var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));

            if (InnerMappers.Length == 0)
            {
                for (int c = 0; c < input.Schema.Count; ++c)
                {
                    var column = input.Schema[c];
                    if (activeIndices.Contains(c) && !input.IsColumnActive(column))
                    {
                        throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{column.Name}' active but it was not.");
                    }
                }
                return input;
            }

            // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
            // what we need from them. The last one is asked for the requested output columns directly; the others
            // are asked for whatever the next mapper in the chain depends on.
            var deps = new IEnumerable<DataViewSchema.Column>[InnerMappers.Length];
            deps[deps.Length - 1] = OutputSchema.Where(c => activeIndices.Contains(c.Index));
            for (int i = deps.Length - 1; i >= 1; --i)
            {
                deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]);
            }

            DataViewRow result = input;

            for (int i = 0; i < InnerMappers.Length; ++i)
            {
                result = InnerMappers[i].GetRow(result, deps[i]);
            }

            return result;
        }
예제 #22
0
        /// <summary>
        /// Creates the predicted-label getter and, through <paramref name="scoreGetter"/>, a matching
        /// score getter. Both read the float score column of <paramref name="output"/> and share one
        /// cached score. The label getter yields a <c>uint</c> when the predicted column type is a
        /// <see cref="KeyDataViewType"/>, and a <c>bool</c> otherwise.
        /// </summary>
        /// <param name="output">The mapper's output row; the score column must be active on it.</param>
        /// <param name="scoreGetter">Receives a getter that returns the cached score.</param>
        /// <returns>The predicted-label getter, as a delegate.</returns>
        protected override Delegate GetPredictedLabelGetter(DataViewRow output, out Delegate scoreGetter)
        {
            Host.AssertValue(output);
            Host.Assert(output.Schema == Bindings.RowMapper.OutputSchema);
            Host.Assert(output.IsColumnActive(output.Schema[Bindings.ScoreColumnIndex]));

            ValueGetter<float> rawScoreGetter = output.GetGetter<float>(output.Schema[Bindings.ScoreColumnIndex]);

            // State shared by all closures below. EnsureCachedPosition presumably refreshes the score
            // only when the row has moved to a new position -- confirm against its definition.
            long lastPosition = -1;
            float cachedScore = 0;

            ValueGetter<float> getScore = (ref float dst) =>
            {
                EnsureCachedPosition(ref lastPosition, ref cachedScore, output, rawScoreGetter);
                dst = cachedScore;
            };
            scoreGetter = getScore;

            Delegate labelGetter;
            if (Bindings.PredColType is KeyDataViewType)
            {
                ValueGetter<uint> getLabelAsKey = (ref uint dst) =>
                {
                    EnsureCachedPosition(ref lastPosition, ref cachedScore, output, rawScoreGetter);
                    GetPredictedLabelCoreAsKey(cachedScore, ref dst);
                };
                labelGetter = getLabelAsKey;
            }
            else
            {
                ValueGetter<bool> getLabel = (ref bool dst) =>
                {
                    EnsureCachedPosition(ref lastPosition, ref cachedScore, output, rawScoreGetter);
                    GetPredictedLabelCore(cachedScore, ref dst);
                };
                labelGetter = getLabel;
            }

            return labelGetter;
        }
예제 #23
0
        /// <summary>
        /// Creates the predicted-label getter and, through <paramref name="scoreGetter"/>, a matching
        /// score getter. Both read the vector score column of <paramref name="output"/> and share one
        /// cached score buffer. The label is the one-based index of the smallest score, or 0 when
        /// <c>VectorUtils.ArgMin</c> returns a negative index.
        /// </summary>
        /// <param name="output">The mapper's output row; the score column must be active on it.</param>
        /// <param name="scoreGetter">Receives a getter that copies out the cached score vector.</param>
        /// <returns>The predicted-label getter, as a delegate.</returns>
        protected override Delegate GetPredictedLabelGetter(DataViewRow output, out Delegate scoreGetter)
        {
            Contracts.AssertValue(output);
            Contracts.Assert(output.Schema == Bindings.RowMapper.OutputSchema);
            Contracts.Assert(output.IsColumnActive(output.Schema[Bindings.ScoreColumnIndex]));

            ValueGetter<VBuffer<float>> rawScoreGetter = output.GetGetter<VBuffer<float>>(output.Schema[Bindings.ScoreColumnIndex]);

            // State shared by both closures below. EnsureCachedPosition presumably refreshes the score
            // only when the row has moved to a new position -- confirm against its definition.
            long lastPosition = -1;
            VBuffer<float> cachedScore = default(VBuffer<float>);

            // Every score vector must have as many slots as the key type has values (0 for a non-key type).
            int expectedLength = 0;
            if (Bindings.PredColType is KeyType keyType)
            {
                expectedLength = keyType.GetCountAsInt32(Host);
            }

            ValueGetter<VBuffer<float>> getScore = (ref VBuffer<float> dst) =>
            {
                EnsureCachedPosition(ref lastPosition, ref cachedScore, output, rawScoreGetter);
                Contracts.Check(cachedScore.Length == expectedLength);
                cachedScore.CopyTo(ref dst);
            };

            ValueGetter<uint> getLabel = (ref uint dst) =>
            {
                EnsureCachedPosition(ref lastPosition, ref cachedScore, output, rawScoreGetter);
                Contracts.Check(cachedScore.Length == expectedLength);

                int best = VectorUtils.ArgMin(in cachedScore);
                dst = best < 0 ? 0u : (uint)best + 1;
            };

            scoreGetter = getScore;
            return getLabel;
        }