예제 #1
0
        /// <summary>
        /// Builds a scorer transform that applies <paramref name="predictor"/> to <paramref name="input"/>.
        /// </summary>
        /// <param name="scorer">Optional scorer factory; <c>null</c> is accepted.</param>
        /// <param name="predictor">The predictor to score with. Must not be <c>null</c>.</param>
        /// <param name="input">The data to score. Must not be <c>null</c>.</param>
        /// <param name="featureColName">Optional name of the column bound to the feature role.</param>
        /// <param name="groupColName">Optional name of the column bound to the group role.</param>
        /// <param name="customColumns">Optional additional role/column bindings.</param>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="trainSchema">Optional training-time role-mapped schema, forwarded to the scorer component.</param>
        /// <param name="mapperFactory">Optional factory forwarded to the scorer/mapper resolution.</param>
        /// <returns>The transform created by the resolved scorer component.</returns>
        public static IDataScorerTransform GetScorer(
            TScorerFactory scorer,
            IPredictor predictor,
            IDataView input,
            string featureColName,
            string groupColName,
            IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> customColumns,
            IHostEnvironment env,
            RoleMappedSchema trainSchema,
            IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValueOrNull(scorer);
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValue(input, nameof(input));
            env.CheckValueOrNull(featureColName);
            env.CheckValueOrNull(groupColName);
            env.CheckValueOrNull(customColumns);
            env.CheckValueOrNull(trainSchema);

            // Only feature, group and custom roles are bound; the label role is deliberately left unbound.
            var inputRoles = new RoleMappedSchema(
                input.Schema,
                label: null,
                feature: featureColName,
                group: groupColName,
                custom: customColumns,
                opt: true);

            var scorerComponent = GetScorerComponentAndMapper(predictor, scorer, inputRoles, env, mapperFactory, out var boundMapper);
            return scorerComponent.CreateComponent(env, input, boundMapper, trainSchema);
        }
예제 #2
0
        /// <summary>
        /// Resolves an <see cref="ISchemaBindableMapper"/> for <paramref name="predictor"/>, trying in order:
        /// <list type="number">
        /// <item><description>the supplied <paramref name="mapperFactory"/>, if any;</description></item>
        /// <item><description>a bindable mapper created from <paramref name="scorerFactorySettings"/> (only succeeds when a
        /// BindableMapper creation method is registered under the same load name as the scorer);</description></item>
        /// <item><description>the predictor itself, when it already implements <see cref="ISchemaBindableMapper"/>;</description></item>
        /// <item><description>a standard wrapper: the binary wrapper for <see cref="IValueMapperDist"/> predictors,
        /// otherwise the generic predictor wrapper.</description></item>
        /// </list>
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="predictor">The predictor to bind. Must not be <c>null</c>.</param>
        /// <param name="mapperFactory">Optional factory that takes precedence over every other strategy.</param>
        /// <param name="scorerFactorySettings">Optional scorer settings used to look up a registered bindable mapper.</param>
        /// <returns>A bindable mapper compatible with <paramref name="predictor"/>.</returns>
        public static ISchemaBindableMapper GetSchemaBindableMapper(
            IHostEnvironment env,
            IPredictor predictor,
            IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory = null,
            ICommandLineComponentFactory scorerFactorySettings = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(mapperFactory);
            env.CheckValueOrNull(scorerFactorySettings);

            // An explicit factory always wins.
            if (mapperFactory != null)
            {
                return mapperFactory.CreateComponent(env, predictor);
            }

            // Next, look for a mapper registered under the scorer's load name.
            if (scorerFactorySettings != null
                && TryCreateBindableFromScorer(env, predictor, scorerFactorySettings, out var scorerBindable))
            {
                return scorerBindable;
            }

            // The predictor may already be bindable on its own.
            if (predictor is ISchemaBindableMapper selfBindable)
            {
                return selfBindable;
            }

            // Fall back to one of the standard wrappers.
            return predictor is IValueMapperDist
                ? (ISchemaBindableMapper)new SchemaBindableBinaryPredictorWrapper(predictor)
                : new SchemaBindablePredictorWrapper(predictor);
        }
예제 #3
0
 /// <summary>
 /// Create a new <see cref="IDataView"/> over an enumerable of the items of user-defined type.
 /// The user maintains ownership of the <paramref name="data"/> and the resulting data view will
 /// never alter the contents of the <paramref name="data"/>.
 /// Since <see cref="IDataView"/> is assumed to be immutable, the user is expected to support
 /// multiple enumerations of the <paramref name="data"/> that would return the same results, unless
 /// the user knows that the data will only be cursored once.
 ///
 /// One typical usage for streaming data view could be: create the data view that lazily loads data
 /// as needed, then apply pre-trained transformations to it and cursor through it for transformation
 /// results.
 /// </summary>
 /// <typeparam name="TRow">The user-defined item type.</typeparam>
 /// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to an <see cref="IDataView"/>.
 /// Must not be <c>null</c>.</param>
 /// <param name="schemaDefinition">The optional schema definition of the data view to create. If <c>null</c>,
 /// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
 /// <returns>The constructed <see cref="IDataView"/>.</returns>
 /// <example>
 /// <format type="text/markdown">
 /// <![CDATA[
 /// [!code-csharp[LoadFromEnumerable](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/LoadFromEnumerable.cs)]
 /// ]]>
 /// </format>
 /// </example>
 public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefinition schemaDefinition = null)
     where TRow : class
 {
     _env.CheckValue(data, nameof(data));
     _env.CheckValueOrNull(schemaDefinition);
     return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
 }
 /// <summary>
 /// Builds an on-demand <see cref="PredictionEngine{TSrc, TDst}"/> from an existing data pipe.
 /// </summary>
 /// <typeparam name="TSrc">The input row type.</typeparam>
 /// <typeparam name="TDst">The output row type; must be a class with a parameterless constructor.</typeparam>
 /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
 /// <param name="dataPipe">The transformation pipe that may or may not include a scorer. Must not be <c>null</c>.</param>
 /// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
 /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
 /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
 /// <returns>A new prediction engine constructed over <paramref name="dataPipe"/>.</returns>
 public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(
     this IHostEnvironment env,
     IDataView dataPipe,
     bool ignoreMissingColumns = false,
     SchemaDefinition inputSchemaDefinition = null,
     SchemaDefinition outputSchemaDefinition = null)
     where TSrc : class
     where TDst : class, new()
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(dataPipe, nameof(dataPipe));
     env.CheckValueOrNull(inputSchemaDefinition);
     env.CheckValueOrNull(outputSchemaDefinition);

     var engine = new PredictionEngine<TSrc, TDst>(env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
     return engine;
 }
 /// <summary>
 /// Builds a <see cref="BatchPredictionEngine{TSrc, TDst}"/> from a serialized model.
 /// </summary>
 /// <typeparam name="TSrc">The input row type.</typeparam>
 /// <typeparam name="TDst">The output row type; must be a class with a parameterless constructor.</typeparam>
 /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
 /// <param name="modelStream">The stream to deserialize the pipeline (transforms and predictor) from. Must not be <c>null</c>.</param>
 /// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
 /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
 /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
 /// <returns>A new batch prediction engine constructed from <paramref name="modelStream"/>.</returns>
 public static BatchPredictionEngine<TSrc, TDst> CreateBatchPredictionEngine<TSrc, TDst>(
     this IHostEnvironment env,
     Stream modelStream,
     bool ignoreMissingColumns = false,
     SchemaDefinition inputSchemaDefinition = null,
     SchemaDefinition outputSchemaDefinition = null)
     where TSrc : class
     where TDst : class, new()
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(modelStream, nameof(modelStream));
     env.CheckValueOrNull(inputSchemaDefinition);
     env.CheckValueOrNull(outputSchemaDefinition);

     var engine = new BatchPredictionEngine<TSrc, TDst>(env, modelStream, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
     return engine;
 }
예제 #6
0
 /// <summary>
 /// Builds an on-demand <see cref="PredictionEngine{TSrc, TDst}"/> from a transformer.
 /// </summary>
 /// <typeparam name="TSrc">The input row type.</typeparam>
 /// <typeparam name="TDst">The output row type; must be a class with a parameterless constructor.</typeparam>
 /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
 /// <param name="transformer">The transformer. Must not be <c>null</c>.</param>
 /// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
 /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TSrc"/> type.</param>
 /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is inferred from the <typeparamref name="TDst"/> type.</param>
 /// <returns>A new prediction engine constructed over <paramref name="transformer"/>.</returns>
 internal static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(
     this IHostEnvironment env,
     ITransformer transformer,
     bool ignoreMissingColumns = false,
     SchemaDefinition inputSchemaDefinition = null,
     SchemaDefinition outputSchemaDefinition = null)
     where TSrc : class
     where TDst : class, new()
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(transformer, nameof(transformer));
     env.CheckValueOrNull(inputSchemaDefinition);
     env.CheckValueOrNull(outputSchemaDefinition);

     var engine = new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
     return engine;
 }
예제 #7
0
        /// <summary>
        /// Wraps <paramref name="data"/> in a <see cref="RoleMappedData"/> with the given column-role bindings.
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="data">The data to wrap. Must not be <c>null</c>.</param>
        /// <param name="features">Name of the column bound to the feature role; <c>null</c> is accepted.</param>
        /// <param name="label">Optional name of the column bound to the label role.</param>
        /// <param name="group">Optional name of the column bound to the group role.</param>
        /// <param name="weight">Optional name of the column bound to the weight role.</param>
        /// <param name="custom">Optional additional role/column bindings.</param>
        /// <returns>A <see cref="RoleMappedData"/> over <paramref name="data"/>; no name role is bound.</returns>
        internal static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView data, string features, string label = null,
            string group = null, string weight = null, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // Reject null data up front through the environment, consistent with the other factory helpers.
            env.CheckValue(data, nameof(data));
            env.CheckValueOrNull(label);
            env.CheckValueOrNull(features);
            env.CheckValueOrNull(group);
            env.CheckValueOrNull(weight);
            env.CheckValueOrNull(custom);

            return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom);
        }
예제 #8
0
        /// <summary>
        /// This creates a filter transform that can 'accept' or 'decline' any row of the data based on the contents of the row
        /// or state of the cursor.
        /// This is a 'stateful non-savable' version of the filter: the filter function is guaranteed to be invoked once per
        /// every row of the data set, in sequence (non-parallelizable); one user-defined state object will be allocated per cursor and passed to the
        /// filter function every time.
        /// If <typeparamref name="TSrc"/> or <typeparamref name="TState"/> implement the <see cref="IDisposable" /> interface, they will be disposed after use.
        /// </summary>
        /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the
        /// input <see cref="IDataView"/>.</typeparam>
        /// <typeparam name="TState">The type of the state object to allocate per cursor.</typeparam>
        /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
        /// <param name="source">The input data to apply transformation to. Must not be <c>null</c>.</param>
        /// <param name="filterFunc">The user-defined function that determines whether to keep the row or discard it. First parameter
        /// is the current row's contents, the second parameter is the cursor-specific state object. Must not be <c>null</c>.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TSrc"/> type.</param>
        /// <returns>A stateful, non-savable transform over <paramref name="source"/> that keeps or discards each row
        /// according to <paramref name="filterFunc"/>.</returns>
        public static ITransformTemplate CreateFilter<TSrc, TState>(IHostEnvironment env, IDataView source,
            Func<TSrc, TState, bool> filterFunc, Action<TState> initStateAction, SchemaDefinition inputSchemaDefinition = null)
            where TSrc : class, new()
            where TState : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(source, nameof(source));
            env.CheckValue(filterFunc, nameof(filterFunc));
            env.CheckValueOrNull(initStateAction);
            env.CheckValueOrNull(inputSchemaDefinition);

            // A filter produces no new columns, so 'object' stands in for the destination type and the
            // adapter lambda ignores its destination argument.
            return new StatefulFilterTransform<TSrc, object, TState>(env, source,
                (src, dst, state) => filterFunc(src, state), initStateAction, null, null, inputSchemaDefinition);
        }
예제 #9
0
        /// <summary>
        /// Creates a stateless, non-savable transform that adds the columns described by <typeparamref name="TDst"/>
        /// to <paramref name="source"/> by running <paramref name="mapAction"/> on each row. The row count is unchanged.
        /// Like other <see cref="IDataTransform"/>s, the result acts both as the transformation (reapplicable to other
        /// data via <see cref="ITransformTemplate.ApplyToData"/>) and as the transformed data (which can be cursored via
        /// <c>GetRowCursor</c> or <c>AsCursorable{TRow}</c>). If <typeparamref name="TSrc"/> or
        /// <typeparamref name="TDst"/> implement <see cref="IDisposable" />, they will be disposed after use.
        /// </summary>
        /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the
        /// input <see cref="IDataView"/>.</typeparam>
        /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
        /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
        /// <param name="source">The input data to apply transformation to. Must not be <c>null</c>.</param>
        /// <param name="mapAction">The transformation. It must be a pure, re-entrant function of the current
        /// <typeparamref name="TSrc"/> object: the system may invoke it on some or all rows, in any order, and from
        /// several threads concurrently. Closures are allowed provided they keep the function re-entrant and
        /// side-effect free. Must not be <c>null</c>.</param>
        /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TSrc"/> type.</param>
        /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TDst"/> type.</param>
        /// <returns>A non-savable mapping transform from the input source to the destination.</returns>
        public static ITransformTemplate CreateMap<TSrc, TDst>(
            IHostEnvironment env,
            IDataView source,
            Action<TSrc, TDst> mapAction,
            SchemaDefinition inputSchemaDefinition = null,
            SchemaDefinition outputSchemaDefinition = null)
            where TSrc : class, new()
            where TDst : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(source, nameof(source));
            env.CheckValue(mapAction, nameof(mapAction));
            env.CheckValueOrNull(inputSchemaDefinition);
            env.CheckValueOrNull(outputSchemaDefinition);

            var mapTransform = new MapTransform<TSrc, TDst>(env, source, mapAction, null, null,
                inputSchemaDefinition, outputSchemaDefinition);
            return mapTransform;
        }
예제 #10
0
        /// <summary>
        /// Configures a reader for text files.
        /// </summary>
        /// <typeparam name="TShape">The type shape parameter, which must be a valid-schema shape. As a practical
        /// matter this is generally not explicitly defined from the user, but is instead inferred from the return
        /// type of the <paramref name="func"/> where one takes an input <see cref="Context"/> and uses it to compose
        /// a shape-type instance describing what the columns are and how to load them from the file.</typeparam>
        /// <param name="env">The environment. Must not be <c>null</c>.</param>
        /// <param name="func">The delegate that describes what fields to read from the text file, as well as
        /// describing their input type. The way in which it works is that the delegate is fed a <see cref="Context"/>,
        /// and the user composes a shape type with <see cref="PipelineColumn"/> instances out of that <see cref="Context"/>.
        /// The resulting data will have columns with the names corresponding to their names in the shape type.
        /// Must not be <c>null</c>.</param>
        /// <param name="files">Input files. If <c>null</c> then no files are read, which means that options or
        /// configurations that require input data for initialization (for example, <paramref name="hasHeader"/>, or
        /// <see cref="Context.LoadFloat(int, int?)"/> with a <c>null</c> second argument) have no data to initialize from.</param>
        /// <param name="hasHeader">Data file has header with feature names.</param>
        /// <param name="separator">Text field separator.</param>
        /// <param name="allowQuoting">Whether the input may include quoted values, which can contain separator
        /// characters, colons, and distinguish empty values from missing values. When true, consecutive separators
        /// denote a missing value and an empty value is denoted by <c>""</c>. When false, consecutive separators
        /// denote an empty value.</param>
        /// <param name="allowSparse">Whether the input may include sparse representations.</param>
        /// <param name="trimWhitspace">Remove trailing whitespace from lines.</param>
        /// <returns>A configured statically-typed reader for text files.</returns>
        public static DataReader<IMultiStreamSource, TShape> CreateReader<[IsShape] TShape>(
            IHostEnvironment env, Func<Context, TShape> func, IMultiStreamSource files = null,
            bool hasHeader = false, char separator = '\t', bool allowQuoting = true, bool allowSparse = true,
            bool trimWhitspace = false)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(func, nameof(func));
            env.CheckValueOrNull(files);

            // Populate all args except the columns; the columns are supplied through the reconciler.
            var args = new TextLoader.Arguments
            {
                AllowQuoting = allowQuoting,
                AllowSparse = allowSparse,
                HasHeader = hasHeader,
                Separators = new[] { separator },
                TrimWhitespace = trimWhitspace,
            };

            var rec = new TextReconciler(args, files);
            var ctx = new Context(rec);

            using (var ch = env.Start("Initializing " + nameof(TextLoader)))
            {
                var readerEst = StaticPipeUtils.ReaderEstimatorAnalyzerHelper(env, ch, ctx, rec, func);
                Contracts.AssertValue(readerEst);
                return readerEst.Fit(files);
            }
        }
예제 #11
0
        /// <summary>
        /// Creates an <see cref="AssemblyRegistrar"/> bound to <paramref name="env"/> and hands it back as an <see cref="IDisposable"/>.
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="loadAssembliesPath">Optional path forwarded to the registrar; <c>null</c> is accepted.</param>
        /// <returns>The registrar; the caller is responsible for disposing it.</returns>
        public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValueOrNull(loadAssembliesPath);

            IDisposable registrar = new AssemblyRegistrar(env, loadAssembliesPath);
            return registrar;
        }
예제 #12
0
        /// <summary>
        /// Given a predictor and an optional scorer SubComponent, produces a compatible ISchemaBindableMapper.
        /// First, it tries to instantiate the bindable mapper using the <paramref name="scorerSettings"/>
        /// (this will only succeed if there's a registered BindableMapper creation method with load name equal to the one
        /// of the scorer).
        /// If the above fails, it checks whether the predictor implements <see cref="ISchemaBindableMapper"/>
        /// directly.
        /// If this also isn't true, it will create a 'matching' standard mapper.
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="predictor">The predictor to bind. Must not be <c>null</c>.</param>
        /// <param name="scorerSettings">Optional scorer settings; <c>null</c> skips the scorer-based lookup.</param>
        /// <returns>A bindable mapper compatible with <paramref name="predictor"/>.</returns>
        public static ISchemaBindableMapper GetSchemaBindableMapper(IHostEnvironment env, IPredictor predictor,
            SubComponent<IDataScorerTransform, SignatureDataScorer> scorerSettings)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(scorerSettings);

            // See if we can instantiate a mapper using scorer arguments. The explicit null test matters because
            // null was just accepted above; it mirrors the guard used by the factory-based overload.
            if (scorerSettings != null && scorerSettings.IsGood()
                && TryCreateBindableFromScorer(env, predictor, scorerSettings, out var bindable))
            {
                return bindable;
            }

            // The easy case is that the predictor implements the interface.
            bindable = predictor as ISchemaBindableMapper;
            if (bindable != null)
            {
                return bindable;
            }

            // Use one of the standard wrappers.
            if (predictor is IValueMapperDist)
            {
                return new SchemaBindableBinaryPredictorWrapper(predictor);
            }

            return new SchemaBindablePredictorWrapper(predictor);
        }
 /// <summary>
 /// Wraps an in-memory list of user-defined items in a new <see cref="IDataView"/>.
 /// The caller keeps ownership of <paramref name="data"/>; the data view never modifies it.
 /// Because an <see cref="IDataView"/> is treated as immutable, the caller must not change
 /// <paramref name="data"/> while the view is actively being cursored.
 ///
 /// A typical pattern: create the view and train a predictor; once training completes,
 /// update the underlying list and train another predictor.
 /// </summary>
 /// <typeparam name="TRow">The user-defined item type.</typeparam>
 /// <param name="env">The host environment to use for data view creation. Must not be <c>null</c>.</param>
 /// <param name="data">The data to wrap around. Must not be <c>null</c>.</param>
 /// <param name="schemaDefinition">The optional schema definition of the data view to create. If <c>null</c>,
 /// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
 /// <returns>The constructed <see cref="IDataView"/>.</returns>
 public static IDataView CreateDataView<TRow>(this IHostEnvironment env, IList<TRow> data, SchemaDefinition schemaDefinition = null)
     where TRow : class
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(data, nameof(data));
     env.CheckValueOrNull(schemaDefinition);

     var view = DataViewConstructionUtils.CreateFromList(env, data, schemaDefinition);
     return view;
 }
예제 #14
0
        /// <summary>
        /// Wraps <paramref name="data"/> in a strongly-typed cursorable view of <typeparamref name="TRow"/> rows.
        /// </summary>
        /// <typeparam name="TRow">The user-defined row type; must be a class with a parameterless constructor.</typeparam>
        /// <param name="data">The underlying data view. Must not be <c>null</c>.</param>
        /// <param name="env">The environment. Must not be <c>null</c>.</param>
        /// <param name="ignoreMissingColumns">Whether to ignore the case when a requested column is not present in the data view.</param>
        /// <param name="schemaDefinition">Optional user-provided schema definition. When absent, the schema is inferred from <typeparamref name="TRow"/>.</param>
        /// <returns>The cursorable wrapper of <paramref name="data"/>.</returns>
        public static ICursorable<TRow> AsCursorable<TRow>(
            this IDataView data,
            IHostEnvironment env,
            bool ignoreMissingColumns = false,
            SchemaDefinition schemaDefinition = null)
            where TRow : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValueOrNull(schemaDefinition);

            ICursorable<TRow> cursorable = TypedCursorable<TRow>.Create(env, data, ignoreMissingColumns, schemaDefinition);
            return cursorable;
        }
        /// <summary>
        /// Builds a scorer transform for <paramref name="predictor"/> over <paramref name="input"/>,
        /// using the scorer described by the <paramref name="scorer"/> sub-component.
        /// </summary>
        /// <param name="scorer">Optional scorer sub-component; <c>null</c> is accepted.</param>
        /// <param name="predictor">The predictor to score with. Must not be <c>null</c>.</param>
        /// <param name="input">The data to score. Must not be <c>null</c>.</param>
        /// <param name="featureColName">Optional name of the feature column.</param>
        /// <param name="groupColName">Optional name of the group column.</param>
        /// <param name="customColumns">Optional additional role/column bindings.</param>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="trainSchema">Optional training-time role-mapped schema, forwarded to the scorer.</param>
        /// <returns>The scorer transform instantiated from the resolved scorer component.</returns>
        public static IDataScorerTransform GetScorer(SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
            IPredictor predictor, IDataView input, string featureColName, string groupColName,
            IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> customColumns, IHostEnvironment env, RoleMappedSchema trainSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValueOrNull(scorer);
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValue(input, nameof(input));
            env.CheckValueOrNull(featureColName);
            env.CheckValueOrNull(groupColName);
            env.CheckValueOrNull(customColumns);
            env.CheckValueOrNull(trainSchema);

            var inputRoles = TrainUtils.CreateRoleMappedSchemaOpt(input.Schema, featureColName, groupColName, customColumns);
            var scorerComponent = GetScorerComponentAndMapper(predictor, scorer, inputRoles, env, out ISchemaBoundMapper boundMapper);
            return scorerComponent.CreateInstance(env, input, boundMapper, trainSchema);
        }
예제 #16
0
        /// <summary>
        /// Creates a scorer for <paramref name="predictor"/> over the role-mapped <paramref name="data"/>
        /// by delegating to <c>ScoreUtils.GetScorer</c>.
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="data">The role-mapped data to score. Must not be <c>null</c>.</param>
        /// <param name="predictor">The predictor to score with. Must not be <c>null</c>.</param>
        /// <param name="trainSchema">Optional training-time role-mapped schema.</param>
        /// <returns>The scorer transform.</returns>
        internal static IDataScorerTransform CreateDefaultScorer(this IHostEnvironment env, RoleMappedData data,
            IPredictor predictor, RoleMappedSchema trainSchema = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(trainSchema);

            var scorerTransform = ScoreUtils.GetScorer(predictor, data, env, trainSchema);
            return scorerTransform;
        }
예제 #17
0
        /// <summary>
        /// Instantiates the trainer described by <c>args.Trainer</c> and forwards to the overload that takes
        /// an already-created trainer.
        /// </summary>
        /// <param name="env">The host environment. Must not be <c>null</c>.</param>
        /// <param name="args">The arguments; both it and its <c>Trainer</c> must be non-<c>null</c>.</param>
        /// <param name="input">The input data. Must not be <c>null</c>.</param>
        /// <param name="mapperFactory">Optional bindable-mapper factory.</param>
        /// <returns>The transform produced by the forwarding overload.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, IComponentFactory<IPredictor, ISchemaBindableMapper> mapperFactory)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(args.Trainer, nameof(args.Trainer),
                "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");
            env.CheckValue(input, nameof(input));
            env.CheckValueOrNull(mapperFactory);

            var trainer = args.Trainer.CreateComponent(env);
            return Create(env, args, trainer, input, mapperFactory);
        }
 /// <summary>
 /// Initializes a cross-validation helper over <paramref name="inputDataView"/>, configured for
 /// <c>args.NumFolds</c> folds split on <paramref name="splitColumn"/>.
 /// </summary>
 /// <param name="env">The environment. Must not be <c>null</c>.</param>
 /// <param name="registrationName">The registration name. Must not be empty or whitespace.</param>
 /// <param name="inputDataView">The input data view. Must not be <c>null</c>.</param>
 /// <param name="splitColumn">The column to use for splitting data into folds. Must not be <c>null</c>.</param>
 /// <param name="args">Cross validation arguments. Must not be <c>null</c>; <c>NumFolds</c> must be greater than 1
 /// and <c>MaxCalibrationExamples</c> must be positive.</param>
 /// <param name="createExamples">The delegate to create RoleMappedData. Must not be <c>null</c>.</param>
 /// <param name="applyTransformsToTestData">The delegate to apply the transforms from the train pipeline to the test data. Must not be <c>null</c>.</param>
 /// <param name="scorer">The scorer.</param>
 /// <param name="evaluator">The evaluator.</param>
 /// <param name="getValidationDataView">The delegate to create validation data view. When supplied,
 /// <paramref name="applyTransformsToValidationData"/> must be supplied as well.</param>
 /// <param name="applyTransformsToValidationData">The delegate to apply the transforms from the train pipeline to the validation data.</param>
 /// <param name="inputPredictor">The input predictor, for the continue training option.</param>
 /// <param name="cmd">The command string.</param>
 /// <param name="loader">Original loader so we can construct correct pipeline for model saving.</param>
 /// <param name="savePerInstance">Whether to produce the per-instance data view.</param>
 public FoldHelper(
     IHostEnvironment env,
     string registrationName,
     IDataView inputDataView,
     string splitColumn,
     Arguments args,
     Func<IHostEnvironment, IChannel, IDataView, ITrainer, RoleMappedData> createExamples,
     Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToTestData,
     SubComponent<IDataScorerTransform, SignatureDataScorer> scorer,
     SubComponent<IMamlEvaluator, SignatureMamlEvaluator> evaluator,
     Func<IDataView> getValidationDataView = null,
     Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToValidationData = null,
     IPredictor inputPredictor = null,
     string cmd = null,
     IDataLoader loader = null,
     bool savePerInstance = false)
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckNonWhiteSpace(registrationName, nameof(registrationName));
     env.CheckValue(inputDataView, nameof(inputDataView));
     env.CheckValue(splitColumn, nameof(splitColumn));
     // Every check below dereferences args, so reject null explicitly rather than
     // failing with a NullReferenceException.
     env.CheckValue(args, nameof(args));
     env.CheckParam(args.NumFolds > 1, nameof(args.NumFolds));
     env.CheckValue(createExamples, nameof(createExamples));
     env.CheckValue(applyTransformsToTestData, nameof(applyTransformsToTestData));
     env.CheckParam(args.Trainer.IsGood(), nameof(args.Trainer));
     env.CheckValueOrNull(scorer);
     env.CheckValueOrNull(evaluator);
     env.CheckValueOrNull(args.Calibrator);
     env.CheckParam(args.MaxCalibrationExamples > 0, nameof(args.MaxCalibrationExamples));
     // A validation data view is only usable if we also know how to transform it.
     env.CheckParam(getValidationDataView == null || applyTransformsToValidationData != null, nameof(applyTransformsToValidationData));
     env.CheckValueOrNull(inputPredictor);
     env.CheckValueOrNull(cmd);
     env.CheckValueOrNull(args.OutputModelFile);
     env.CheckValueOrNull(loader);

     _env = env;
     _registrationName = registrationName;
     _inputDataView = inputDataView;
     _splitColumn = splitColumn;
     _numFolds = args.NumFolds;
     _createExamples = createExamples;
     _applyTransformsToTestData = applyTransformsToTestData;
     _trainer = args.Trainer;
     _scorer = scorer;
     _evaluator = evaluator;
     _calibrator = args.Calibrator;
     _maxCalibrationExamples = args.MaxCalibrationExamples;
     _useThreads = args.UseThreads;
     _cacheData = args.CacheData;
     _getValidationDataView = getValidationDataView;
     _applyTransformsToValidationData = applyTransformsToValidationData;
     _inputPredictor = inputPredictor;
     _cmd = cmd;
     _outputModelFile = args.OutputModelFile;
     _loader = loader;
     _savePerInstance = savePerInstance;
 }
예제 #19
0
        /// <summary>
        /// Creates a stateful, non-savable map transform. The mapping function runs exactly once per row, in order;
        /// each cursor gets its own user-defined state object, which is passed to every call of the mapping function.
        /// If <typeparamref name="TSrc"/>, <typeparamref name="TDst"/>, or <typeparamref name="TState"/> implement
        /// <see cref="IDisposable" />, they will be disposed after use.
        /// </summary>
        /// <typeparam name="TSrc">The type that describes what 'source' columns are consumed from the
        /// input <see cref="IDataView"/>.</typeparam>
        /// <typeparam name="TDst">The type that describes what new columns are added by this transform.</typeparam>
        /// <typeparam name="TState">The type of the state object to allocate per cursor.</typeparam>
        /// <param name="env">The host environment to use. Must not be <c>null</c>.</param>
        /// <param name="source">The input data to apply transformation to. Must not be <c>null</c>.</param>
        /// <param name="mapAction">Transforms its <typeparamref name="TSrc"/> argument into its <typeparamref name="TDst"/>
        /// argument, optionally using the per-cursor <typeparamref name="TState"/>. Must not be <c>null</c>.</param>
        /// <param name="initStateAction">Called once per cursor to initialize the state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The optional input schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TSrc"/> type.</param>
        /// <param name="outputSchemaDefinition">The optional output schema. If <c>null</c>, the schema is
        /// inferred from the <typeparamref name="TDst"/> type.</param>
        /// <returns>A stateful, non-savable transform over <paramref name="source"/> that keeps every row.</returns>
        public static ITransformTemplate CreateMap<TSrc, TDst, TState>(
            IHostEnvironment env,
            IDataView source,
            Action<TSrc, TDst, TState> mapAction,
            Action<TState> initStateAction,
            SchemaDefinition inputSchemaDefinition = null,
            SchemaDefinition outputSchemaDefinition = null)
            where TSrc : class, new()
            where TDst : class, new()
            where TState : class, new()
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(source, nameof(source));
            env.CheckValue(mapAction, nameof(mapAction));
            env.CheckValueOrNull(initStateAction);
            env.CheckValueOrNull(outputSchemaDefinition == outputSchemaDefinition ? inputSchemaDefinition : inputSchemaDefinition);
            env.CheckValueOrNull(outputSchemaDefinition);

            // The map is expressed as a filter that never rejects a row.
            bool MapAndKeep(TSrc src, TDst dst, TState state)
            {
                mapAction(src, dst, state);
                return true;
            }

            return new StatefulFilterTransform<TSrc, TDst, TState>(env, source, MapAndKeep, initStateAction,
                null, null, inputSchemaDefinition, outputSchemaDefinition);
        }
예제 #20
0
        /// <summary>
        /// Convert an <see cref="IDataView"/> into a strongly-typed <see cref="IEnumerable{TRow}"/>.
        /// </summary>
        /// <typeparam name="TRow">The user-defined row type.</typeparam>
        /// <param name="data">The underlying data view.</param>
        /// <param name="env">The environment. Must not be null.</param>
        /// <param name="reuseRowObject">Whether to return the same object on every row, or allocate a new one per row.</param>
        /// <param name="ignoreMissingColumns">Whether to ignore the case when a requested column is not present in the data view.</param>
        /// <param name="schemaDefinition">Optional user-provided schema definition. If it is not present, the schema is inferred from the definition of T.</param>
        /// <returns>The <see cref="IEnumerable{TRow}"/> that holds the data in <paramref name="data"/>. It can be enumerated multiple times.</returns>
        public static IEnumerable <TRow> AsEnumerable <TRow>(this IDataView data, IHostEnvironment env, bool reuseRowObject,
                                                             bool ignoreMissingColumns = false, SchemaDefinition schemaDefinition = null)
            where TRow : class, new()
        {
            // This is a public entry point, so env is validated with a Check (active in release builds),
            // not an Assert (debug-only), which would let a null env surface as a NullReferenceException below.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValueOrNull(schemaDefinition);

            var engine = new PipeEngine <TRow>(env, data, ignoreMissingColumns, schemaDefinition);

            return(engine.RunPipe(reuseRowObject));
        }
        // REVIEW: AppendRowsDataView now only checks schema consistency up to column names and types.
        // A future task will be to ensure that the sources are consistent on the metadata level.

        /// <summary>
        /// Produces a data view whose rows are the rows of every source, one source after another.
        ///
        /// Every source has to agree with the supplied schema on column count, column names and column types.
        /// When <paramref name="schema"/> is null, the schema of the first source is used instead.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="schema">Schema of the result; null means "take it from the first source".</param>
        /// <param name="sources">Data views to concatenate; at least one, none of them null.</param>
        /// <returns>The concatenated view, or the only source itself when exactly one is given.</returns>
        public static IDataView Create(IHostEnvironment env, Schema schema, params IDataView[] sources)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(sources, nameof(sources));
            env.CheckNonEmpty(sources, nameof(sources), "There must be at least one source.");
            env.CheckParam(!sources.Any(view => view == null), nameof(sources));
            env.CheckValueOrNull(schema);

            // Concatenating a single view is a no-op, so that view is handed back untouched.
            return sources.Length == 1
                ? sources[0]
                : new AppendRowsDataView(env, schema, sources);
        }
        /// <summary>
        /// Writes a transformer model, together with the schema of the data it was trained on, into <paramref name="stream"/>.
        /// </summary>
        /// <param name="model">The trained model to persist.</param>
        /// <param name="inputSchema">Schema of the transformer's input; may be null.</param>
        /// <param name="stream">Destination stream; it has to be writeable and seekable.</param>
        public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
        {
            _env.CheckValue(model, nameof(model));
            _env.CheckValueOrNull(inputSchema);
            _env.CheckValue(stream, nameof(stream));

            using (var repository = RepositoryWriter.CreateNew(stream))
            {
                // The transformer lives in the composite loader's transformer directory; the schema is stored next to it.
                ModelSaveContext.SaveModel(repository, model, CompositeDataLoader<object, ITransformer>.TransformerDirectory);
                SaveInputSchema(inputSchema, repository);
                repository.Commit();
            }
        }
예제 #23
0
        /// <summary>
        /// Looks up the loadable class registered under <paramref name="name"/> (lower-cased and trimmed) for
        /// <paramref name="signatureType"/> and, if one is found, instantiates it.
        /// </summary>
        /// <typeparam name="TRes">The type the created instance must be assignable to.</typeparam>
        /// <param name="env">The host environment.</param>
        /// <param name="signatureType">A delegate type identifying the component signature.</param>
        /// <param name="result">The created instance, or null when no matching class is registered.</param>
        /// <param name="name">The load name of the class; null is treated as the empty name.</param>
        /// <param name="options">Settings string parsed into the class's arguments object.</param>
        /// <param name="extra">Extra constructor arguments; their count must match the registration.</param>
        /// <returns><c>false</c> if no class is registered for the name/signature pair; <c>true</c> once an instance is created.</returns>
        /// <remarks>
        /// Throws (via <paramref name="env"/>) if the class does not derive from <typeparamref name="TRes"/>, if the number of
        /// extra arguments is wrong, if options are given to a class that takes none, or if its arguments object cannot be created.
        /// </remarks>
        private static bool TryCreateInstance <TRes>(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
            where TRes : class
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate before dereferencing .BaseType, so a null signature is reported as a bad argument
            // instead of a NullReferenceException.
            env.CheckValue(signatureType, nameof(signatureType));
            env.Check(signatureType.BaseType == typeof(MulticastDelegate));
            env.CheckValueOrNull(name);

            string            nameLower = (name ?? "").ToLowerInvariant().Trim();
            LoadableClassInfo info      = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));

            if (info == null)
            {
                result = null;
                return(false);
            }

            if (!typeof(TRes).IsAssignableFrom(info.Type))
            {
                throw env.Except("Loadable class '{0}' does not derive from '{1}'", name, typeof(TRes).FullName);
            }

            int carg = Utils.Size(extra);

            if (info.ExtraArgCount != carg)
            {
                throw env.Except(
                          "Wrong number of extra parameters for loadable class '{0}', need '{1}', given '{2}'",
                          name, info.ExtraArgCount, carg);
            }

            if (info.ArgType == null)
            {
                // The class takes no arguments object, so any non-empty settings string is an error.
                if (!string.IsNullOrEmpty(options))
                {
                    throw env.Except("Loadable class '{0}' doesn't support settings", name);
                }
                result = (TRes)info.CreateInstance(env, null, extra);
                return(true);
            }

            object args = info.CreateArguments();

            if (args == null)
            {
                // Report through the host environment, consistent with every other failure in this method.
                throw env.Except("Can't instantiate arguments object '{0}' for '{1}'", info.ArgType.Name, name);
            }

            ParseArguments(env, args, options, name);
            result = (TRes)info.CreateInstance(env, args, extra);
            return(true);
        }
예제 #24
0
        /// <summary>
        /// Builds a data scorer described by a 'LoadName{settings}' string.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="settings">Scorer settings string.</param>
        /// <param name="data">Data to be scored.</param>
        /// <param name="predictor">Predictor whose scores are produced.</param>
        /// <param name="trainSchema">Optional training-data schema the scorer may mine for extra information
        /// (label names, for instance). When <c>null</c>, nothing is extracted from it.</param>
        /// <returns>The scoring transform applied to <paramref name="data"/>.</returns>
        public static IDataScorerTransform CreateScorer(this IHostEnvironment env, string settings,
            RoleMappedData data, Predictor predictor, RoleMappedSchema trainSchema = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(trainSchema);

            // Turn the settings string into a factory so the bindable mapper can be matched to the requested scorer.
            ICommandLineComponentFactory factorySettings = ParseScorerSettings(settings);
            var boundMapper = ScoreUtils
                .GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: factorySettings)
                .Bind(env, data.Schema);

            return CreateCore<IDataScorerTransform, SignatureDataScorer>(env, settings, data.Data, boundMapper, trainSchema);
        }
        /// <summary>
        /// Follows the Source links of <paramref name="chain"/> backwards until <paramref name="oldSource"/> is reached
        /// (or, when <paramref name="oldSource"/> is <c>null</c>, until something that is not an <see cref="IDataTransform"/>),
        /// then re-applies every transform passed on the way, in original order, on top of <paramref name="newSource"/>.
        /// An empty chain is allowed; the result is then just <paramref name="newSource"/>.
        /// If <paramref name="oldSource"/> is given but never encountered, an exception is thrown.
        /// </summary>
        /// <param name="env">Environment to use.</param>
        /// <param name="chain">Last view of the pipeline.</param>
        /// <param name="newSource">Data the re-created pipeline is attached to.</param>
        /// <param name="oldSource">Start of the pipe that must not be re-applied; null means every transform is re-applied.</param>
        /// <returns>The rebuilt pipeline's final view.</returns>
        public static IDataView ApplyAllTransformsToData(IHostEnvironment env, IDataView chain, IDataView newSource, IDataView oldSource = null)
        {
            // REVIEW: have a variation that would selectively apply transforms?
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(chain, nameof(chain));
            env.CheckValue(newSource, nameof(newSource));
            env.CheckValueOrNull(oldSource);

            // REVIEW: a composite data loader is 'unwrapped' here so that its pipeline is walked step by step.
            // It would arguably be more robust if CompositeDataLoader were not an IDataView at all, which would
            // force callers to do the unwrapping themselves.
            IDataView Unwrap(IDataView view) => view is CompositeDataLoader composite ? composite.View : view;

            // Walk backwards; the stack makes the later enumeration come out oldest-transform-first.
            var pending = new Stack<IDataTransform>();
            IDataView cursor = Unwrap(chain);
            while (cursor is IDataTransform current && cursor != oldSource)
            {
                pending.Push(current);
                cursor = Unwrap(current.Source);
            }

            env.Check(oldSource == null || cursor == oldSource, "Source data not found in the chain");

            IDataView rebuilt = newSource;
            foreach (IDataTransform step in pending)
            {
                rebuilt = ApplyTransformToData(env, step, rebuilt);
            }

            return rebuilt;
        }
예제 #26
0
        /// <summary>
        /// Create a TransformModel containing the given (optional) transforms applied to the
        /// given root schema.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="schemaRoot">The schema the transforms are applied to.</param>
        /// <param name="xfs">Transforms to re-apply, in order; may be null or empty, but must not contain null elements.</param>
        public TransformModel(IHostEnvironment env, Schema schemaRoot, IDataTransform[] xfs)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(schemaRoot, nameof(schemaRoot));
            env.CheckValueOrNull(xfs);

            // An empty view carries only the schema; the transforms are re-bound onto it.
            IDataView view = new EmptyDataView(env, schemaRoot);

            _schemaRoot = view.Schema;

            if (Utils.Size(xfs) > 0)
            {
                foreach (var xf in xfs)
                {
                    // xfs comes from the caller, so a null element must be rejected in release builds too;
                    // an Assert here would only fire in debug builds.
                    env.CheckParam(xf != null, nameof(xfs), "Transforms should not be null");
                    view = ApplyTransformUtils.ApplyTransformToData(env, xf, view);
                }
            }

            _chain = view;
        }
예제 #27
0
        /// <summary>
        /// Builds a data scorer described by a 'LoadName{settings}' string.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="settings">Scorer settings string.</param>
        /// <param name="data">Data to be scored.</param>
        /// <param name="predictor">Predictor whose scores are produced.</param>
        /// <param name="trainSchema">Optional training-data schema the scorer may mine for extra information,
        /// such as label names. When <c>null</c>, nothing is extracted from it.</param>
        /// <returns>The scoring transform applied to <paramref name="data"/>.</returns>
        internal static IDataScorerTransform CreateScorer(this IHostEnvironment env, string settings,
            RoleMappedData data, IPredictor predictor, RoleMappedSchema trainSchema = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckValue(predictor, nameof(predictor));
            env.CheckValueOrNull(trainSchema);

            // The scorer factory type and its signature are needed both to parse the settings and to create the scorer.
            Type scorerFactoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>);
            Type scorerSignature   = typeof(SignatureDataScorer);

            ICommandLineComponentFactory factorySettings =
                CmdParser.CreateComponentFactory(scorerFactoryType, scorerSignature, settings);

            var boundMapper = ScoreUtils
                .GetSchemaBindableMapper(env, predictor, scorerFactorySettings: factorySettings)
                .Bind(env, data.Schema);

            return CreateCore<IDataScorerTransform>(env, scorerFactoryType, scorerSignature, settings, data.Data, boundMapper, trainSchema);
        }
예제 #28
0
        /// <summary>
        /// Writes a transformer model, together with the loader that produced its input data, into <paramref name="stream"/>.
        /// </summary>
        /// <param name="model">The trained model to persist. It may be <see langword="null"/>, meaning "an empty transformer
        /// chain"; loading the result with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
        /// then yields an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
        /// <param name="loader">The loader used to create the training data.</param>
        /// <param name="stream">Destination stream; it has to be writeable and seekable.</param>
        public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, Stream stream)
        {
            _env.CheckValue(loader, nameof(loader));
            _env.CheckValueOrNull(model);
            _env.CheckValue(stream, nameof(stream));

            // For consistency of this particular API, a non-null model is always stored wrapped
            // in a one-element transformer chain; a null model stays null.
            TransformerChain<ITransformer> wrappedModel = null;
            if (model != null)
            {
                wrappedModel = new TransformerChain<ITransformer>(model);
            }

            var loaderWithModel = new CompositeDataLoader<TSource, ITransformer>(loader, wrappedModel);

            using (var repository = RepositoryWriter.CreateNew(stream))
            {
                ModelSaveContext.SaveModel(repository, loaderWithModel, null);
                repository.Commit();
            }
        }
        /// <summary>
        /// Validates the inputs and, when the training schema allows it, returns a mapper whose output score column
        /// carries slot names taken from the training label's key values. Whenever that is not possible,
        /// the supplied bound mapper is returned unchanged.
        /// </summary>
        private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(mapper, nameof(mapper));
            env.CheckValueOrNull(trainSchema);

            // The key values of the training label become slot-name metadata on the score column.
            // Several conditions must hold for that to happen. If any of them fails, the mapper is
            // returned as-is.

            // Condition 1: the training schema exists and identifies a label column.
            if (trainSchema?.Label == null)
            {
                return mapper;
            }

            // Condition 2: the label exposes key-values metadata of a vector type that CanWrap accepts for this mapper.
            var keyValuesType = trainSchema.Label.Value.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type as VectorType;
            bool canWrap = keyValuesType != null && CanWrap(mapper, keyValuesType);
            if (!canWrap)
            {
                return mapper;
            }

            // All conditions hold: dispatch to the generic WrapCore for the key values' item type.
            return Utils.MarshalInvoke(WrapCore<int>, keyValuesType.ItemType.RawType, env, mapper, trainSchema);
        }