Example #1
0
 /// <summary>
 /// Deserialization constructor: reads the arguments, restores the inner scorer (sub-model "scorer")
 /// bound to <paramref name="input"/>, and rebuilds the source pipeline around it.
 /// </summary>
 public TaggedScoreTransform(IHost host, ModelLoadContext ctx, IDataView input)
     : base(host, ctx, input, LoaderSignature)
 {
     // Arguments come first in the stream, before the scorer sub-model.
     _args = new Arguments();
     _args.Read(ctx, _host);

     ctx.LoadModel<IDataScorerTransform, SignatureLoadDataTransform>(_host, out _scorer, "scorer", input);

     // Recreate the pipeline using the scorer that was just restored.
     _sourcePipe = Create(_host, _args, input, out _sourceCtx, _scorer);
 }
Example #2
0
 /// <summary>
 /// Fills every slot of <paramref name="predictors"/> with the sub-model stored in the directory
 /// obtained by formatting <c>SubPredictorFmt</c> with that slot's index.
 /// </summary>
 private static void LoadPredictors<TPredictor>(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx)
     where TPredictor : class
 {
     for (int slot = 0; slot < predictors.Length; slot++)
     {
         string dirName = string.Format(SubPredictorFmt, slot);
         ctx.LoadModel<TPredictor, SignatureLoadModel>(env, out predictors[slot], dirName);
     }
 }
        /// <summary>
        /// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
        /// applies them sequentially to the <paramref name="srcLoader"/>, and returns the (composite) data loader.
        /// When no transform is accepted, <paramref name="srcLoader"/> itself is returned.
        /// </summary>
        private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcLoader, IHost host, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.AssertValue(host, nameof(host));
            host.AssertValue(srcLoader);
            host.AssertValue(ctx);

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            int floatSize = ctx.Reader.ReadInt32();
            host.CheckDecode(floatSize == sizeof(Float));

            int transformCount = ctx.Reader.ReadInt32();
            host.CheckDecode(transformCount >= 0);

            // Before VersionAddedTags no tag/args strings were stored, so every transform gets the empty tag.
            bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
            var accepted = new List<KeyValuePair<string, string>>();
            var sourceIndices = new List<int>();

            for (int xfIndex = 0; xfIndex < transformCount; xfIndex++)
            {
                string tag = "";
                string argsString = null;
                if (hasTags)
                {
                    tag = ctx.LoadNonEmptyString();
                    argsString = ctx.LoadStringOrNull();
                }

                if (isTransformTagAccepted(tag))
                {
                    // Keep the original position: it names the sub-directory the transform lives in.
                    sourceIndices.Add(xfIndex);
                    accepted.Add(new KeyValuePair<string, string>(tag, argsString));
                }
            }

            host.Assert(accepted.Count == sourceIndices.Count);
            if (accepted.Count == 0)
            {
                return srcLoader;
            }

            return ApplyTransformsCore(host, srcLoader, accepted.ToArray(),
                (h, index, data) =>
                {
                    IDataTransform xf;
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out xf,
                        string.Format(TransformDirTemplate, sourceIndices[index]), data);
                    return xf;
                });
        }
Example #4
0
            /// <summary>
            /// Deserialization constructor. The base constructor reads its own state; this type then
            /// restores only the output combiner from the "Combiner" sub-model.
            /// </summary>
            protected SchemaBindablePipelineEnsemble(IHostEnvironment env, ModelLoadContext ctx, string scoreColumnKind)
                : base(env, ctx, scoreColumnKind)
            {
                // *** Binary format ***
                // <base>
                // The combiner (sub-model "Combiner")
                ctx.LoadModel<IOutputCombiner<T>, SignatureLoadModel>(Host, out Combiner, "Combiner");
            }
        /// <summary>
        /// Deserialization constructor: restores the row mapper (sub-model "Mapper") against the input schema,
        /// then rebuilds the bindings and the output-column metadata.
        /// </summary>
        private RowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            // *** Binary format ***
            // _mapper

            ctx.LoadModel<IRowMapper, SignatureLoadRowMapper>(host, out _mapper, "Mapper", input.Schema);
            _bindings = new Bindings(input.Schema, this);

            var outputMetadata = _mapper.GetOutputColumns().Select(info => info.Metadata);
            CreateMetadata(outputMetadata, _bindings, out _md);
        }
        /// <summary>
        /// Loads the sub-predictors stored as "M2B0", "M2B1", ... followed by an optional reclassification
        /// predictor. The reclassification predictor is present only when the next byte equals 1; otherwise
        /// <paramref name="reclassPredictor"/> is set to null.
        /// </summary>
        private static void LoadPredictors<TPredictor>(IHostEnvironment env,
            TPredictor[] predictors, out IPredictor reclassPredictor, ModelLoadContext ctx)
            where TPredictor : class
        {
            for (int k = 0; k < predictors.Length; k++)
            {
                ctx.LoadModel<TPredictor, SignatureLoadModel>(env, out predictors[k], string.Format("M2B{0}", k));
            }

            // Any flag value other than 1 means no reclassification model was saved.
            reclassPredictor = null;
            if (ctx.Reader.ReadByte() == 1)
            {
                ctx.LoadModel<IPredictor, SignatureLoadModel>(env, out reclassPredictor, "Reclassification");
            }
        }
Example #7
0
        /// <summary>
        /// Deserialization constructor: restores the loader (sub-model "Loader", bound to a
        /// <c>MultiFileSource(null)</c>) and re-applies each serialized transform in order on top of it.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Load context positioned at this model; must not be null.</param>
        public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: number of transforms
            // sub-model "Loader"
            // foreach transform: sub-model in the TransformDirTemplate directory
            int n = ctx.Reader.ReadInt32();
            // A negative count can only come from a corrupt model; fail instead of silently loading nothing.
            env.CheckDecode(n >= 0);

            ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

            IDataView data = loader;
            for (int i = 0; i < n; i++)
            {
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
                data = xf;
            }

            _env = env;
            _xf = data;
        }
        /// <summary>
        /// Deserialization constructor: loads the wrapped predictor and derives its score type
        /// (and value mapper) from it.
        /// </summary>
        protected SchemaBindablePredictorWrapperBase(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.AssertValue(ctx);

            // *** Binary format ***
            // <nothing>
            // (the predictor is a sub-model under ModelFileUtils.DirPredictor)

            ctx.LoadModel<IPredictor, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
            ScoreType = GetScoreType(Predictor, out ValueMapper);
        }
        /// <summary>
        /// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
        /// applies them sequentially to the <paramref name="srcView"/>, and returns the resulting data view.
        /// If there are no transforms in <paramref name="ctx"/> that are accepted, returns the original <paramref name="srcView"/>.
        /// The difference from the <c>Create</c> method above is that:
        /// - it doesn't wrap the results into a loader, just returns the last transform in the chain.
        /// - it accepts <see cref="IDataView"/> as input, not necessarily a loader.
        /// - it throws away the tag information.
        /// - it doesn't throw if the context is not representing a <see cref="LegacyCompositeDataLoader"/>: in this case it's assumed that no transforms
        ///   meet the test, and the <paramref name="srcView"/> is returned.
        /// Essentially, this is a helper method for the LoadTransform class.
        /// </summary>
        public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView srcView, IHostEnvironment env, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(ctx, nameof(ctx));
            // The reader must sit exactly at the start of this model's data.
            host.Check(ctx.Reader.BaseStream.Position == ctx.FpMin + ctx.Header.FpModel);

            var versionInfo = GetVersionInfo();
            bool isCompositeModel = ctx.Header.ModelSignature == versionInfo.ModelSignature;
            if (!isCompositeModel)
            {
                // A different model kind: there is nothing to apply, so hand back the input view.
                using (var ch = host.Start("ModelCheck"))
                {
                    ch.Info("The data model doesn't contain transforms.");
                }
                return srcView;
            }
            ModelHeader.CheckVersionInfo(ref ctx.Header, versionInfo);

            host.CheckValue(srcView, nameof(srcView));
            host.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted));

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            int floatSize = ctx.Reader.ReadInt32();
            host.CheckDecode(floatSize == sizeof(float));

            int transformCount = ctx.Reader.ReadInt32();
            host.CheckDecode(transformCount >= 0);

            bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
            IDataView current = srcView;
            for (int xfIndex = 0; xfIndex < transformCount; xfIndex++)
            {
                // Older models carry no tags; their transforms are tested with the empty tag.
                string tag = "";
                if (hasTags)
                {
                    tag = ctx.LoadNonEmptyString();
                    // The args string must still be consumed to keep the stream aligned.
                    ctx.LoadStringOrNull();
                }
                if (!isTransformTagAccepted(tag))
                {
                    continue;
                }

                string dirName = string.Format(TransformDirTemplate, xfIndex);
                ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out IDataTransform xf, dirName, current);
                current = xf;
            }

            return current;
        }
Example #10
0
        /// <summary>
        /// Deserialization constructor: loads the tree ensemble (sub-model "Ensemble") and stores the
        /// result of <c>CountLeaves</c> on it in <c>_totalLeafCount</c>.
        /// </summary>
        public TreeEnsembleFeaturizerBindableMapper(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(LoaderSignature);
            _host.AssertValue(ctx);

            // *** Binary format ***
            // ensemble

            ctx.LoadModel<FastTreePredictionWrapper, SignatureLoadModel>(env, out _ensemble, "Ensemble");
            _totalLeafCount = CountLeaves(_ensemble);
        }
        // Factory method for SignatureLoadModel.
        /// <summary>
        /// Deserialization constructor: records the loader signature, verifies the model header and
        /// restores the calibrator from the "Calibrator" sub-model.
        /// </summary>
        private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
        {
            Contracts.AssertValue(ctx);

            // The signature is assigned before the header check, which uses GetVersionInfo().
            _loaderSignature = loaderSignature;
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // model: _calibrator
            ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");
        }
Example #12
0
            /// <summary>
            /// Deserialization constructor: restores the loader (sub-model "Loader", bound to a
            /// <c>MultiFileSource(null)</c>) and re-applies each serialized transform in order.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="ctx">Load context positioned at this model; must not be null.</param>
            public Transformer(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Transformer));
                _host.CheckValue(ctx, nameof(ctx));

                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // int: number of transforms
                // sub-model "Loader"
                // foreach transform: sub-model in the TransformDirTemplate directory
                int n = ctx.Reader.ReadInt32();
                // A negative count can only come from a corrupt model; report it as a decode failure.
                _host.CheckDecode(n >= 0);

                ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

                IDataView data = loader;
                for (int i = 0; i < n; i++)
                {
                    var dirName = string.Format(TransformDirTemplate, i);
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
                    data = xf;
                }

                _xf = data;
            }
Example #13
0
        /// <summary>
        /// Load a <see cref="ICanForecast{T}"/> model.
        /// </summary>
        /// <typeparam name="T">The type of <see cref="ICanForecast{T}"/>, usually float.</typeparam>
        /// <param name="catalog"><see cref="ModelOperationsCatalog"/></param>
        /// <param name="filePath">File path to load the model from.</param>
        /// <returns><see cref="ICanForecast{T}"/> model.</returns>
        public static ICanForecast<T> LoadForecastingModel<T>(this ModelOperationsCatalog catalog, string filePath)
        {
            var env = CatalogUtils.GetEnvironment(catalog);

            // The repository reader is disposed first, then the underlying file stream.
            using (var stream = File.OpenRead(filePath))
            using (var repository = RepositoryReader.Open(stream, env))
            {
                ModelLoadContext.LoadModel<ICanForecast<T>, SignatureLoadModel>(env, out var forecaster, repository, LoaderSignature);
                return forecaster;
            }
        }
            /// <summary>
            /// Deserialization constructor: reads the column count and the per-column slot counts,
            /// then loads the base count table from the "BaseTable" sub-model.
            /// </summary>
            public MultiCountTable(IHostEnvironment env, ModelLoadContext ctx)
                : base(env, LoaderSignature)
            {
                Host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // int: ColCount
                // int[]: SlotCount (ColCount entries)
                // count table (in a separate folder)

                ColCount = ctx.Reader.ReadInt32();
                SlotCount = ctx.Reader.ReadIntArray(ColCount);

                ctx.LoadModel<CountTableBase, SignatureLoadModel>(Host, out _baseTable, "BaseTable");
            }
Example #15
0
        /// <summary>
        /// Deserialization constructor: reads the number of chained transforms and loads them from
        /// "XF0", "XF1", ...; each transform is bound to the output of the previous one, and the first
        /// is bound to <paramref name="input"/>.
        /// </summary>
        private ChainTransform(IHost host, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(host, nameof(host));
            Contracts.CheckValue(input, nameof(input));
            _host = host;
            _input = input;
            _host.CheckValue(input, nameof(input));
            _host.CheckValue(ctx, nameof(ctx));

            int count = ctx.Reader.ReadInt32();
            _host.Check(count > 0, "nb");

            _dataTransforms = new IDataTransform[count];
            IDataView upstream = input;
            for (int i = 0; i < count; ++i)
            {
                ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out _dataTransforms[i], string.Format("XF{0}", i), upstream);
                upstream = _dataTransforms[i];
            }
        }
Example #16
0
        /// <summary>
        /// Deserialization constructor: reads the validation-set proportion and restores the meta predictor
        /// from the "MetaPredictor" sub-model.
        /// </summary>
        private protected BaseStacking(IHostEnvironment env, string name, ModelLoadContext ctx)
        {
            Contracts.AssertValue(env);
            env.AssertNonWhiteSpace(name);
            Host = env.Register(name);
            Host.AssertValue(ctx);

            // *** Binary format ***
            // int: sizeof(Single)
            // Float: _validationDatasetProportion
            int floatSize = ctx.Reader.ReadInt32();
            env.CheckDecode(floatSize == sizeof(Single));

            ValidationDatasetProportion = ctx.Reader.ReadFloat();
            // The stored proportion must lie in [0, 1).
            env.CheckDecode(0 <= ValidationDatasetProportion && ValidationDatasetProportion < 1);

            ctx.LoadModel<IPredictorProducing<TOutput>, SignatureLoadModel>(env, out Meta, "MetaPredictor");
            CheckMeta();
        }
Example #17
0
            /// <summary>
            /// Deserialization constructor: loads the inner bindable mapper, then a non-null vector-typed
            /// value that is decoded into <c>_getter</c>, and finally the metadata kind
            /// (only stored from <c>VersionAddedMetadataKind</c> on).
            /// </summary>
            private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
            {
                Contracts.AssertValue(host);
                _host = host;
                _host.AssertValue(ctx);

                ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(_host, out _bindable, _innerDir);

                var saver = new BinarySaver(_host, new BinarySaver.Arguments());
                _host.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out ColumnType type, out object value));
                _host.CheckDecode(type.IsVector);
                _host.CheckDecode(value != null);

                _type = type.AsVector;
                _getter = Utils.MarshalInvoke(DecodeInit<int>, _type.ItemType.RawType, value);

                // Older models did not persist the kind and fall back to slot names.
                if (ctx.Header.ModelVerReadable >= VersionAddedMetadataKind)
                {
                    _metadataKind = ctx.LoadNonEmptyString();
                }
                else
                {
                    _metadataKind = MetadataUtils.Kinds.SlotNames;
                }
            }
        /// <summary>
        /// Loads the entire composite data loader (loader + transforms) from the context.
        /// If there are no transforms, the underlying loader is returned.
        /// </summary>
        public static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            host.CheckValue(files, nameof(files));

            using (var ch = host.Start("Components"))
            {
                // The underlying loader is the "Loader" sub-model, reading from the supplied files.
                ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(host, out IDataLoader loader, "Loader", files);

                // A composite loader is not expected here as the base of the transforms.
                host.Assert(!(loader is CompositeDataLoader));

                // Every transform is accepted, whatever its tag.
                return LoadTransforms(ctx, loader, host, tag => true);
            }
        }
Example #19
0
        /// <summary>
        /// Deserialization constructor. It restores three things:
        /// - the model;
        /// - the training schema, read through a <c>BinaryLoader</c> over an in-memory copy of the schema stream;
        /// - the optional feature column and its type.
        /// It throws a schema-mismatch exception if the feature column is absent from that schema.
        /// </summary>
        internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
        {
            Host = host;

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // id of string: feature column.

            ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
            Model = model;

            // Copy the schema stream into memory before handing it to the binary loader.
            var schemaStream = new MemoryStream();
            ctx.TryLoadBinaryStream(DirTransSchema, reader => reader.BaseStream.CopyTo(schemaStream));
            schemaStream.Position = 0;

            var schemaLoader = new BinaryLoader(host, new BinaryLoader.Arguments(), schemaStream);
            TrainSchema = schemaLoader.Schema;

            FeatureColumn = ctx.LoadStringOrNull();
            if (FeatureColumn != null)
            {
                if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
                }
                FeatureColumnType = TrainSchema.GetColumnType(col);
            }
            else
            {
                FeatureColumnType = null;
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
Example #20
0
            /// <summary>
            /// Deserialization constructor: after the base state, reads the forecasting settings and the running
            /// state, then restores the sequence modeler and initializes the state against this instance.
            /// </summary>
            public SsaForecastingBase(IHostEnvironment env, ModelLoadContext ctx, string name) : base(env, ctx, name)
            {
                // *** Binary format ***
                // <base>
                // bool: IsAdaptive
                // int32: Horizon
                // float: ConfidenceLevel
                // State: StateRef
                // SequenceModelerBase<Single, Single>: Model (sub-model "SSA")

                // This type requires the base state to have InitialWindowSize == 0.
                Host.CheckDecode(InitialWindowSize == 0);

                IsAdaptive = ctx.Reader.ReadBoolean();
                Horizon = ctx.Reader.ReadInt32();
                ConfidenceLevel = ctx.Reader.ReadFloat();
                StateRef = new State(ctx.Reader);

                ctx.LoadModel<SequenceModelerBase<Single, Single>, SignatureLoadModel>(env, out Model, "SSA");
                Host.CheckDecode(Model != null);
                StateRef.InitState(this, Host);
            }
Example #21
0
        // Factory method for SignatureLoadModel.
        /// <summary>
        /// Deserialization constructor: verifies the header and restores the calibrator sub-model.
        /// It then reads the score column name, which is stored only for models written at version
        /// 0x00010002 or later; older models use <c>DefaultColumnNames.Score</c>.
        /// </summary>
        private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer<TICalibrator>)))
        {
            Contracts.AssertValue(ctx);

            _loaderSignature = loaderSignature;
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // model: _calibrator
            // scoreColumnName: _scoreColumnName
            ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");

            bool hasScoreColumnName = ctx.Header.ModelVerWritten >= 0x00010002;
            _scoreColumnName = hasScoreColumnName ? ctx.LoadString() : DefaultColumnNames.Score;
        }
Example #22
0
        /// <summary>
        /// Deserialization constructor: reads the arguments and checks a marker byte (177). It then loads the
        /// predictor when <c>_args.serialize</c> is set and rebuilds the source pipeline.
        /// </summary>
        public TaggedPredictTransform(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, ctx, input, LoaderSignature)
        {
            // Marker byte expected right after the arguments; any other value is treated as corruption.
            const byte Marker = 177;

            _args = new Arguments();
            _args.Read(ctx, _host);

            if (ctx.Reader.ReadByte() != Marker)
            {
                throw _host.Except("Corrupt file.");
            }

            // The predictor sub-model is only read when the arguments say it was serialized.
            _predictor = null;
            if (_args.serialize)
            {
                ctx.LoadModel<IPredictor, SignatureLoadModel>(_host, out _predictor, "predictor");
            }

            _sourcePipe = Create(_host, _args, input, out _input, _predictor);
        }
            /// <summary>
            /// Deserialization constructor. It restores the feature-contribution predictor and builds a generic
            /// schema-bindable mapper over it. It then reads and validates the top/bottom contribution counts
            /// (each in (0, MaxTopBottom]) and the normalize/stringify flags.
            /// </summary>
            public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // model: Predictor (sub-model under ModelFileUtils.DirPredictor)
                // int: topContributionsCount
                // int: bottomContributionsCount
                // bool (byte): normalize
                // bool (byte): stringify
                ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

                _topContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom);
                _bottomContributionsCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom);

                _normalize = ctx.Reader.ReadBoolByte();
                Stringify = ctx.Reader.ReadBoolByte();
            }
        /// <summary>
        /// The loading constructor of transformer chain. Reverse of <see cref="Save(ModelSaveContext)"/>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Load context positioned at this model; must not be null.</param>
        internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));

            // *** Binary format ***
            // int: number of transformers
            // foreach transformer:
            //     int: TransformerScope
            //     sub-model in the TransformDirTemplate directory
            int len = ctx.Reader.ReadInt32();
            // Report a corrupt count as a decode failure rather than letting the array allocation
            // below throw an OverflowException for a negative size.
            env.CheckDecode(len >= 0);

            _transformers = new ITransformer[len];
            _scopes = new TransformerScope[len];
            for (int i = 0; i < len; i++)
            {
                _scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.LoadModel<ITransformer, SignatureLoadModel>(env, out _transformers[i], dirName);
            }

            // Null when the chain is empty or the last transformer is not a TLastTransformer.
            LastTransformer = len > 0 ? _transformers[len - 1] as TLastTransformer : null;
        }
Example #25
0
        /// <summary>
        /// Deserialization constructor. It reads the arguments and the trend predictor (sub-model "trend"),
        /// then requires exactly one configured column. Finally it builds an extended schema that adds that
        /// column's output name typed as R4, and the transform derived from the trend.
        /// </summary>
        /// <exception>Throws a user-argument exception when <c>columns</c> is null or does not hold exactly one entry.</exception>
        private DeTrendTransform(IHost host, ModelLoadContext ctx, IDataView input) :
            base(host, input)
        {
            Host.CheckValue(input, "input");
            Host.CheckValue(ctx, "ctx");
            _args = new Arguments();
            _args.Read(ctx, Host);

            ctx.LoadModel<IPredictor, SignatureLoadModel>(host, out _trend, "trend");

            // The exception must actually be thrown; previously it was only constructed, so invalid
            // arguments fell through to _args.columns[0] and failed with an unrelated exception.
            if (_args.columns == null || _args.columns.Length != 1)
            {
                throw Host.ExceptUserArg(nameof(_args.columns), "One column must be specified.");
            }
            int index = SchemaHelper.GetColumnIndex(input.Schema, _args.columns[0].Source);

            _schema = Schema.Create(new ExtendedSchema(input.Schema,
                                                       new[] { _args.columns[0].Name },
                                                       new[] { NumberType.R4 /*input.Schema.GetColumnType(index)*/ }));
            _lock = new object();
            _transform = BuildTransform(_trend);
        }
Example #26
0
        /// <summary>
        /// Deserialization constructor. It loads the data loader once and re-saves it into an in-memory
        /// repository. It then keeps a factory that reloads the loader from that snapshot for a given source.
        /// </summary>
        public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx)
        {
            ctx.CheckAtModel(GetVersionInfo());
            ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

            // Snapshot of the loader, reopened each time the factory is invoked.
            var snapshot = new MemoryStream();
            using (var writer = RepositoryWriter.CreateNew(snapshot))
            {
                ModelSaveContext.SaveModel(writer, loader, "Loader");
                writer.Commit();
            }

            _env = env;
            _loaderFactory = (IMultiStreamSource source) =>
            {
                using (var reader = RepositoryReader.Open(snapshot))
                {
                    ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var ldr, reader, "Loader", source);
                    return ldr;
                }
            };
        }
Example #27
0
        /// <summary>
        /// Creates a scorer over <paramref name="input"/> using the predictor stored in <c>args.InputModelFile</c>.
        /// The role-mapped training schema stored alongside it is also passed on; per its name,
        /// <c>LoadRoleMappedSchemaOrNull</c> may yield null.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; must not be null and must name an input model file.</param>
        /// <param name="input">Data to score; must not be null.</param>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate args before it is dereferenced below (previously a null args caused a NullReferenceException).
            env.CheckValue(args, nameof(args));
            env.CheckValue(input, nameof(input));
            env.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile), nameof(args.InputModelFile), "The input model file is required.");

            IPredictor predictor;
            RoleMappedSchema trainSchema = null;

            using (var file = env.OpenInputFile(args.InputModelFile))
            using (var strm = file.OpenReadStream())
            using (var rep = RepositoryReader.Open(strm, env))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);
                trainSchema = ModelFileUtils.LoadRoleMappedSchemaOrNull(env, rep);
            }

            string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

            return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
        }
        // Factory method for SignatureDataTransform.
        /// <summary>
        /// Creates the tree-featurizer transform in one of two ways.
        /// - When <c>args.TrainedModelFile</c> is set, the predictor is loaded from that file and checked for
        ///   feature-count compatibility. It is unwrapped from a calibrated predictor if needed, then bound as a
        ///   <c>GenericScorer</c>.
        /// - Otherwise <c>args.Trainer</c> is used through a TrainAndScore transform.
        /// </summary>
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
                "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };

                if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    if (args.Trainer != null)
                    {
                        ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                    }

                    ch.Trace("Loading model");
                    IPredictor model;
                    using (Stream stream = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                    using (var repository = RepositoryReader.Open(stream, ch))
                    {
                        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out model, repository, ModelFileUtils.DirPredictor);
                    }

                    ch.Trace("Creating scorer");
                    var roleData = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                    Contracts.Assert(roleData.Schema.Feature.HasValue);

                    // Work with the uncalibrated sub-predictor.
                    if (model is CalibratedPredictorBase calibrated)
                    {
                        model = calibrated.SubPredictor;
                    }

                    // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
                    // be non-null.
                    var valueMapper = model as IValueMapper;
                    ch.CheckUserArg(valueMapper != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

                    // Make sure that the given predictor has the correct number of input features.
                    int expectedFeatures = valueMapper.InputType.VectorSize;
                    int actualFeatures = roleData.Schema.Feature.Value.Type.VectorSize;
                    if (expectedFeatures != actualFeatures)
                    {
                        throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                            "Predictor in model file expects {0} features, but data has {1} features",
                            expectedFeatures, actualFeatures);
                    }

                    ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, model);
                    var bound = bindable.Bind(env, roleData.Schema);
                    return new GenericScorer(env, scorerArgs, input, bound, roleData.Schema);
                }
                else
                {
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                        (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    // No label transform was appended, so the scorer already sits directly on the input.
                    if (input == labelInput)
                    {
                        return scoreXf;
                    }
                    // NOTE(review): presumably rebases the chain built on labelInput back onto input.
                    return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                }
            }
        }
 /// <summary>
 /// Deserialization constructor: loads the wrapped predictor from the "predictor" sub-model and
 /// rejects a null result.
 /// </summary>
 private WrappedPredictorWithNoDistInterface(IHostEnvironment env, ModelLoadContext ctx)
 {
     ctx.LoadModel<IPredictor, SignatureLoadModel>(env, out _predictor, "predictor");
     Contracts.CheckValue(_predictor, nameof(_predictor));
 }
Example #30
0
 /// <summary>
 /// Deserialization constructor: restores <c>Bindable</c> from the "SchemaBindableMapper" sub-model.
 /// </summary>
 private protected RowToRowScorerBase(IHost host, ModelLoadContext ctx, IDataView input)
     : base(host, input)
 {
     // *** Binary format ***
     // sub-model "SchemaBindableMapper": Bindable
     ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(host, out Bindable, "SchemaBindableMapper");
 }