Ejemplo n.º 1
0
 /// <summary>
 /// Creates a prediction engine over an existing data pipeline. The pipeline is validated with
 /// <c>env.CheckRef</c>, wrapped in a <c>TransformWrapper</c>, and everything is forwarded to another
 /// constructor overload.
 /// </summary>
 /// <param name="env">The host environment.</param>
 /// <param name="dataPipe">The data pipeline to wrap.</param>
 /// <param name="ignoreMissingColumns">Forwarded unchanged to the delegated constructor.</param>
 /// <param name="inputSchemaDefinition">Optional schema overrides for the input type; forwarded unchanged.</param>
 /// <param name="outputSchemaDefinition">Optional schema overrides for the output type; forwarded unchanged.</param>
 internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
                           SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
     : this(
         env,
         new TransformWrapper(env, env.CheckRef(dataPipe, nameof(dataPipe))),
         ignoreMissingColumns,
         inputSchemaDefinition,
         outputSchemaDefinition)
 {
     // All work happens in the delegated constructor.
 }
Ejemplo n.º 2
0
 /// <summary>
 /// Creates a prediction engine from a fitted transformer. The transformer is passed through
 /// <c>TransformerChecker</c> and the result is forwarded to another constructor overload.
 /// </summary>
 /// <param name="env">The host environment.</param>
 /// <param name="transformer">The transformer to score with.</param>
 /// <param name="ignoreMissingColumns">Forwarded unchanged to the delegated constructor.</param>
 /// <param name="inputSchemaDefinition">Optional schema overrides for the input type; forwarded unchanged.</param>
 /// <param name="outputSchemaDefinition">Optional schema overrides for the output type; forwarded unchanged.</param>
 internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
                           SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
     : this(
         env,
         TransformerChecker(env, transformer),
         ignoreMissingColumns,
         inputSchemaDefinition,
         outputSchemaDefinition)
 {
     // All work happens in the delegated constructor.
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Builds an engine that reads <typeparamref name="TDst"/> rows from <paramref name="pipe"/> through a
        /// cursorable view created by <c>AsCursorable</c>.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="pipe">The data view to read from.</param>
        /// <param name="ignoreMissingColumns">Forwarded to <c>AsCursorable</c>.</param>
        /// <param name="schemaDefinition">Optional schema overrides for <typeparamref name="TDst"/>; may be null.</param>
        internal PipeEngine(IHostEnvironment env, IDataView pipe, bool ignoreMissingColumns, SchemaDefinition schemaDefinition = null)
        {
            Contracts.AssertValue(env);
            env.AssertValue(pipe);
            env.AssertValueOrNull(schemaDefinition);

            // Reset the counter, then bind the typed cursorable over the pipe.
            _counter = 0;
            _cursorablePipe = pipe.AsCursorable<TDst>(env, ignoreMissingColumns, schemaDefinition);
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Creates a prediction engine from a serialized model. The stream is passed through
 /// <c>StreamChecker</c> and the result is forwarded to another constructor overload.
 /// </summary>
 /// <param name="env">The host environment.</param>
 /// <param name="modelStream">Stream holding the saved model.</param>
 /// <param name="ignoreMissingColumns">Forwarded unchanged to the delegated constructor.</param>
 /// <param name="inputSchemaDefinition">Optional schema overrides for the input type; forwarded unchanged.</param>
 /// <param name="outputSchemaDefinition">Optional schema overrides for the output type; forwarded unchanged.</param>
 internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
                           SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
     : this(
         env,
         StreamChecker(env, modelStream),
         ignoreMissingColumns,
         inputSchemaDefinition,
         outputSchemaDefinition)
 {
     // All work happens in the delegated constructor.
 }
Ejemplo n.º 5
0
        /// <summary>
        /// Creates a batch prediction engine from a serialized model. An empty <typeparamref name="TSrc"/> data view
        /// is built first, and the model pipeline is loaded on top of that view.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="modelStream">Stream holding the saved model; consumed by
        /// <c>DataViewConstructionUtils.LoadPipeWithPredictor</c>.</param>
        /// <param name="ignoreMissingColumns">Forwarded to the output <c>PipeEngine</c>.</param>
        /// <param name="inputSchemaDefinition">Optional schema overrides for <typeparamref name="TSrc"/>; may be null.</param>
        /// <param name="outputSchemaDefinition">Optional schema overrides for <typeparamref name="TDst"/>; may be null.</param>
        internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
                                       SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(modelStream);
            Contracts.AssertValueOrNull(inputSchemaDefinition);
            Contracts.AssertValueOrNull(outputSchemaDefinition);

            // Initialize pipe. The source view is built over an empty sequence and serves as the data
            // the model pipeline is loaded onto. Array.Empty reuses a cached instance instead of
            // allocating a new zero-length array per engine (CA1825).
            _srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, System.Array.Empty<TSrc>(), inputSchemaDefinition);
            var pipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, _srcDataView);

            _pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Creates a batch prediction engine from an existing data pipeline. An empty <typeparamref name="TSrc"/> data
        /// view is built, and all transforms of <paramref name="dataPipeline"/> are re-applied on top of it.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="dataPipeline">The pipeline whose transforms are applied via
        /// <c>ApplyTransformUtils.ApplyAllTransformsToData</c>.</param>
        /// <param name="ignoreMissingColumns">Forwarded to the output <c>PipeEngine</c>.</param>
        /// <param name="inputSchemaDefinition">Optional schema overrides for <typeparamref name="TSrc"/>; may be null.</param>
        /// <param name="outputSchemaDefinition">Optional schema overrides for <typeparamref name="TDst"/>; may be null.</param>
        internal BatchPredictionEngine(IHostEnvironment env, IDataView dataPipeline, bool ignoreMissingColumns,
                                       SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(dataPipeline);
            Contracts.AssertValueOrNull(inputSchemaDefinition);
            Contracts.AssertValueOrNull(outputSchemaDefinition);

            // Initialize pipe. The source view is built over an empty sequence and serves as the data
            // the pipeline's transforms are re-applied to. Array.Empty reuses a cached instance instead
            // of allocating a new zero-length array per engine (CA1825).
            _srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, System.Array.Empty<TSrc>(), inputSchemaDefinition);
            var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, dataPipeline, _srcDataView);

            _pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Connects <paramref name="inputRow"/> to <paramref name="mapper"/> and exposes the mapped row as a typed
        /// <typeparamref name="TDst"/> reader.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="inputRow">The row the mapper reads from.</param>
        /// <param name="mapper">The row-to-row mapper producing the output row.</param>
        /// <param name="ignoreMissingColumns">Forwarded to <c>TypedCursorable.Create</c>.</param>
        /// <param name="inputSchemaDefinition">Not read by this implementation (presumably kept for overrides).</param>
        /// <param name="outputSchemaDefinition">Optional schema overrides for <typeparamref name="TDst"/>.</param>
        /// <param name="disposer">Receives <c>inputRow.Dispose</c>.</param>
        /// <param name="outputRow">Receives the typed view over the row returned by <c>mapper.GetRow</c>.</param>
        private protected virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
                                                            SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
        {
            // An empty view is used only to hand the mapper's output schema to TypedCursorable.Create.
            var schemaOnlyView = new EmptyDataView(env, mapper.OutputSchema);
            var typedCursorable = TypedCursorable<TDst>.Create(env, schemaOnlyView, ignoreMissingColumns, outputSchemaDefinition);

            // The column predicate accepts every column index.
            var mappedRow = mapper.GetRow(inputRow, col => true);

            outputRow = typedCursorable.GetRow(mappedRow);
            disposer = inputRow.Dispose;
        }
Ejemplo n.º 8
0
        /// <summary>
        /// Initializes the engine from a transformer. It stores the transformer, builds the typed input row, and
        /// hands the mapper produced for that row's schema to <c>PredictionEngineCore</c>.
        /// </summary>
        /// <param name="env">The host environment; checked with <c>Contracts.CheckValue</c>.</param>
        /// <param name="transformer">The transformer to score with.</param>
        /// <param name="ignoreMissingColumns">Forwarded to <c>PredictionEngineCore</c>.</param>
        /// <param name="inputSchemaDefinition">Optional schema overrides for <typeparamref name="TSrc"/>.</param>
        /// <param name="outputSchemaDefinition">Optional schema overrides for <typeparamref name="TDst"/>.</param>
        private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
                                               SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.AssertValue(transformer);
            Transformer = transformer;

            var mapperFactory = TransformerChecker(env, transformer);
            env.AssertValue(mapperFactory);

            _inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
            var rowMapper = mapperFactory(_inputRow.Schema);

            // PredictionEngineCore is a virtual hook: an override in a derived type runs here,
            // before that derived type's constructor body has executed.
            PredictionEngineCore(env, _inputRow, rowMapper, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition,
                                 out _disposer, out _outputRow);
        }
        /// <summary>
        /// Create a filter transform that is savable if and only if both <paramref name="saveAction"/> and
        /// <paramref name="loadFunc"/> are non-null.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="source">The data view upon which we construct the transform.</param>
        /// <param name="filterFunc">The function by which we transform source to destination columns and decide whether
        /// to keep the row.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="saveAction">An action that allows us to save state to the serialization stream. May be
        /// null simultaneously with <paramref name="loadFunc"/>.</param>
        /// <param name="loadFunc">A function that given the serialization stream and a data view, returns
        /// an <see cref="ITransformTemplate"/>. The intent is, this returned object should itself be a
        /// <see cref="CustomMappingTransformer{TSrc,TDst}"/>, but this is not strictly necessary. This delegate should be
        /// a static non-lambda method that this assembly can legally call. May be null simultaneously with
        /// <paramref name="saveAction"/>.</param>
        /// <param name="inputSchemaDefinition">The schema definition overrides for <typeparamref name="TSrc"/>. May be null.</param>
        /// <param name="outputSchemaDefinition">The schema definition overrides for <typeparamref name="TDst"/>. May be null,
        /// in which case the output schema is derived from <typeparamref name="TDst"/> itself using
        /// <c>SchemaDefinition.Direction.Write</c>.</param>
        public StatefulFilterTransform(IHostEnvironment env, IDataView source, Func<TSrc, TDst, TState, bool> filterFunc,
                                       Action<TState> initStateAction,
                                       Action<BinaryWriter> saveAction, LambdaTransform.LoadDelegate loadFunc,
                                       SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
            : base(env, RegistrationName, saveAction, loadFunc)
        {
            Host.AssertValue(source, nameof(source));
            Host.AssertValue(filterFunc, nameof(filterFunc));
            Host.AssertValueOrNull(initStateAction);
            Host.AssertValueOrNull(inputSchemaDefinition);
            Host.AssertValueOrNull(outputSchemaDefinition);

            _source = source;
            _filterFunc = filterFunc;
            _initStateAction = initStateAction;
            _inputSchemaDefinition = inputSchemaDefinition;
            _typedSource = TypedCursorable<TSrc>.Create(Host, Source, false, inputSchemaDefinition);

            // The SchemaDefinition overload of InternalSchemaDefinition.Create reads
            // userSchemaDefinition.Count unconditionally, so the default (null) output override would
            // fail there. In that case, use the Direction-based overload, which first generates a
            // SchemaDefinition from TDst.
            var outSchema = outputSchemaDefinition == null
                ? InternalSchemaDefinition.Create(typeof(TDst), SchemaDefinition.Direction.Write)
                : InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);

            _addedSchema = outSchema;
            _bindings = new ColumnBindings(Schema.Create(Source.Schema), DataViewConstructionUtils.GetSchemaColumns(outSchema));
        }
        /// <summary>
        /// Translates the user-supplied <paramref name="userSchemaDefinition"/> for <paramref name="userType"/> into an
        /// <see cref="InternalSchemaDefinition"/>. Each entry is resolved either to a public field/property of the type
        /// or to a computed-column generator. Its column type is then inferred from that member's type, or validated
        /// against the declared <c>ColumnType</c>.
        /// </summary>
        /// <param name="userType">The type whose members back the non-computed columns.</param>
        /// <param name="userSchemaDefinition">The column definitions to translate. Null is only rejected by
        /// <c>Contracts.AssertValue</c> (presumably a debug-only check); otherwise a null value fails at <c>.Count</c>.
        /// Callers must pass a non-null definition.</param>
        /// <returns>A definition whose column array has one slot per entry of <paramref name="userSchemaDefinition"/>.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>Contracts.ExceptParam</c> when an entry has a null member name, names a
        /// member missing from <paramref name="userType"/>, is computed but has no return type, or declares a column type
        /// whose vector-ness or item kind disagrees with the backing member.
        /// </remarks>
        public static InternalSchemaDefinition Create(Type userType, SchemaDefinition userSchemaDefinition)
        {
            Contracts.AssertValue(userType);
            Contracts.AssertValue(userSchemaDefinition);

            // One output slot per user-defined column, filled in the same order.
            Column[] dstCols = new Column[userSchemaDefinition.Count];

            for (int i = 0; i < userSchemaDefinition.Count; ++i)
            {
                var col = userSchemaDefinition[i];
                if (col.MemberName == null)
                {
                    throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Null field name detected in schema definition");
                }

                bool       isVector;
                DataKind   kind;
                // Stays null for computed columns; those are built from col.Generator instead.
                MemberInfo memberInfo = null;

                if (!col.IsComputed)
                {
                    // Look for a public field first, then fall back to a public property of the same name.
                    memberInfo = userType.GetField(col.MemberName);

                    if (memberInfo == null)
                    {
                        memberInfo = userType.GetProperty(col.MemberName);
                    }

                    if (memberInfo == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
                                                    col.MemberName,
                                                    userType.FullName);
                    }

                    //Clause to handle the field that may be used to expose the cursor channel.
                    //This field does not need a column.
                    // NOTE(review): this `continue` skips the dstCols[i] assignment below, so that slot stays
                    // null and is still handed to the InternalSchemaDefinition constructor. Confirm that
                    // consumers of the resulting column array tolerate null entries.
                    if ((memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
                        (memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
                    {
                        continue;
                    }

                    // Derive vector-ness and item kind from the member's declared type.
                    GetVectorAndKind(memberInfo, out isVector, out kind);
                }
                else
                {
                    // Computed column: vector-ness and item kind come from the generator's declared return type.
                    var parameterType = col.ReturnType;
                    if (parameterType == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column.");
                    }
                    GetVectorAndKind(parameterType, "returnType", out isVector, out kind);
                }
                // Infer the column name.
                var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName;
                // REVIEW: Because order is defined, we allow duplicate column names, since producing an IDataView
                // with duplicate column names is completely legal. Possible objection is that we should make it less
                // convenient to produce "hidden" columns, since this may not be of practical use to users.

                ColumnType colType;
                if (col.ColumnType == null)
                {
                    // Infer a type as best we can.
                    PrimitiveType itemType = PrimitiveType.FromKind(kind);
                    colType = isVector ? new VectorType(itemType) : (ColumnType)itemType;
                }
                else
                {
                    // Make sure that the types are compatible with the declared type, including
                    // whether it is a vector type.
                    if (isVector != col.ColumnType.IsVector)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to be {1}, but type of associated field '{2}' is {3}",
                                                    colName, col.ColumnType.IsVector ? "vector" : "scalar", col.MemberName, isVector ? "vector" : "scalar");
                    }
                    if (kind != col.ColumnType.ItemType.RawKind)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item kind {1}, but associated field has kind {2}",
                                                    colName, col.ColumnType.ItemType.RawKind, kind);
                    }
                    colType = col.ColumnType;
                }

                // Computed columns carry their generator; member-backed columns carry the resolved MemberInfo.
                dstCols[i] = col.IsComputed ?
                             new Column(colName, colType, col.Generator, col.Metadata)
                    : new Column(colName, colType, memberInfo, col.Metadata);
            }
            return(new InternalSchemaDefinition(dstCols));
        }
        /// <summary>
        /// Builds an internal schema definition for <paramref name="userType"/>. It generates a
        /// <see cref="SchemaDefinition"/> for the given <paramref name="direction"/> via <c>SchemaDefinition.Create</c>
        /// and then translates it with the <see cref="SchemaDefinition"/>-based overload.
        /// </summary>
        /// <param name="userType">The type to describe.</param>
        /// <param name="direction">The direction passed to <c>SchemaDefinition.Create</c>.</param>
        /// <returns>The translated internal schema definition.</returns>
        public static InternalSchemaDefinition Create(Type userType, Direction direction)
            => Create(userType, SchemaDefinition.Create(userType, direction));