/// <summary>
        /// Attempt to apply the data transform to a different data view source.
        /// If the transform in question implements <see cref="ITransformTemplate"/>, <see cref="ITransformTemplate.ApplyToData"/>
        /// is called. Otherwise, the transform is serialized into a byte array and then deserialized.
        /// </summary>
        /// <param name="env">The host to use</param>
        /// <param name="transform">The transform to apply.</param>
        /// <param name="newSource">The data view to apply the transform to.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataTransform ApplyTransformToData(IHostEnvironment env, IDataTransform transform, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(transform, nameof(transform));
            env.CheckValue(newSource, nameof(newSource));
            var rebindable = transform as ITransformTemplate;

            if (rebindable != null)
            {
                return rebindable.ApplyToData(env, newSource);
            }

            // Revert to serialization.
            using (var stream = new MemoryStream())
            {
                using (var rep = RepositoryWriter.CreateNew(stream, env))
                {
                    ModelSaveContext.SaveModel(rep, transform, "model");
                    rep.Commit();
                }

                stream.Position = 0;
                using (var rep = RepositoryReader.Open(stream, env))
                {
                    IDataTransform newData;
                    ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(env,
                        out newData, rep, "model", newSource);
                    return newData;
                }
            }
        }
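A minimal usage sketch, not from the source: rebinding a trained transform to a new data view. `trainedXf` and `testView` are hypothetical placeholders, and the containing class is assumed to be ApplyTransformUtils, as in calls elsewhere in these examples.

        // Hypothetical: `trainedXf` was fit on training data; `testView` has a compatible schema.
        IDataTransform rebound = ApplyTransformUtils.ApplyTransformToData(env, trainedXf, testView);
        // Rebinds via ITransformTemplate when supported, otherwise falls back to the
        // serialize/deserialize round trip shown above. `rebound` is an IDataView over `testView`.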
Example #2
        /// <summary>
        /// Trains a model on a single example, round-trips it through serialization,
        /// and checks that the loaded predictor produces the same output.
        /// </summary>
        /// <typeparam name="TOutput">The predictor's output type.</typeparam>
        /// <param name="trainerMaker">Creates the trainer under test, given a host.</param>
        /// <param name="checker">Compares the outputs of the original and loaded predictors.</param>
        private static void TrivialHelper<TOutput>(Func<ITrainerHost, ITrainer<Instances, IPredictorProducing<TOutput>>> trainerMaker, Action<TOutput, TOutput> checker)
        {
            // The following simple instance should result in a "trivial" predictor for binary classification, regression, and multiclass, I think.
            ListInstances instances = new ListInstances();

            instances.AddInst(new Float[] { (Float)0.0 }, (Float)0);
            instances.CopyMetadata(null);
            ITrainerHost host = new TrainHost(new Random(1), 0);

            var trainer = trainerMaker(host);

            trainer.Train(instances);
            IPredictor<Instance, TOutput> predictor       = (IPredictor<Instance, TOutput>)trainer.CreatePredictor();
            IPredictor<Instance, TOutput> loadedPredictor = default(IPredictor<Instance, TOutput>);

            using (Stream stream = new MemoryStream())
            {
                using (RepositoryWriter writer = RepositoryWriter.CreateNew(stream, false))
                {
                    ModelSaveContext.SaveModel(writer, predictor, "foo");
                    writer.Commit();
                }
                stream.Position = 0;
                using (RepositoryReader reader = RepositoryReader.Open(stream, false))
                {
                    ModelLoadContext.LoadModel(out loadedPredictor, reader, "foo");
                }
                Assert.AreNotEqual(default(IPredictor<Instance, TOutput>), loadedPredictor, "did not load expected model");
            }

            TOutput result       = predictor.Predict(instances[0]);
            TOutput loadedResult = loadedPredictor.Predict(instances[0]);

            checker(result, loadedResult);
        }
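A hypothetical invocation of this helper; `SomeRegressionTrainer` is an illustrative stand-in, not a trainer from the source.

        // `SomeRegressionTrainer` is a placeholder for whatever ITrainer the test covers.
        TrivialHelper<Float>(
            host => new SomeRegressionTrainer(host),
            (expected, actual) => Assert.AreEqual(expected, actual, "loaded model predicts differently"));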
Example #3
        public static void LoadModel(IHostEnvironment env, Stream modelStream, bool loadNames, out IPredictor predictor, out RoleMappedSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(modelStream, nameof(modelStream));

            schema = null;
            using (var rep = RepositoryReader.Open(modelStream, env))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);

                if (loadNames)
                {
                    var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, rep);
                    if (roles != null)
                    {
                        var emptyView = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null));
                        schema = new RoleMappedSchema(emptyView.Schema, roles, opt: true);
                    }
                    else
                    {
                        FeatureNameCollection names;
                        if (ModelFileUtils.TryLoadFeatureNames(out names, rep))
                        {
                            schema = names.Schema;
                        }
                    }
                }
            }
        }
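A minimal sketch of calling the helper above, assuming a valid model file; the path and names are illustrative.

        private static void LoadPredictorExample(IHostEnvironment env, string modelPath)
        {
            using (var stream = File.OpenRead(modelPath))
            {
                LoadModel(env, stream, true /* loadNames */, out IPredictor predictor, out RoleMappedSchema schema);
                // `schema` may come back null when the model carries neither role mappings nor feature names.
            }
        }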
            private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
            {
                Host.Assert(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile));

                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, Host))
                using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
                using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
                    return LegacyCompositeDataLoader.Create(Host, ctx, srcData, x => true);
            }
Example #5
        private ILegacyDataLoader CreateLoaderFromBytes(byte[] loaderBytes, IMultiStreamSource files)
        {
            Contracts.CheckValue(loaderBytes, nameof(loaderBytes));
            Contracts.CheckValue(files, nameof(files));

            using (var stream = new MemoryStream(loaderBytes))
            using (var rep = RepositoryReader.Open(stream, _host))
            {
                return ModelFileUtils.LoadLoader(_host, rep, files, false);
            }
        }
        private static IDataView LoadTransforms(IHostEnvironment env, IDataView input, string modelFile)
        {
            var view = input;
            using (var file = env.OpenInputFile(modelFile))
            using (var strm = file.OpenReadStream())
            using (var rep = RepositoryReader.Open(strm, env))
            {
                view = ModelFileUtils.LoadTransforms(env, view, rep);
            }
            return view;
        }
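A short usage sketch with illustrative names: reapplying the transform chain stored in a saved model to fresh data.

        // `env` and `rawData` are hypothetical; the transforms stored in model.zip are applied on top.
        IDataView transformed = LoadTransforms(env, rawData, "model.zip");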
            protected ILegacyDataLoader CreateRawLoader(
                Func <IHostEnvironment, IMultiStreamSource, ILegacyDataLoader> defaultLoaderFactory = null,
                string dataFile = null)
            {
                if (string.IsNullOrWhiteSpace(dataFile))
                {
                    dataFile = ImplOptions.DataFile;
                }

                ILegacyDataLoader loader;

                if (!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile) && ImplOptions.Loader == null)
                {
                    // Load the loader from the data model.
                    using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                    using (var strm = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(strm, Host))
                        loader = LoadLoader(rep, dataFile, ImplOptions.LoadTransforms ?? true);
                }
                else
                {
                    // Either there is no input model file, or there is, but the loader is overridden.
                    IMultiStreamSource fileSource = new MultiFileSource(dataFile);
                    var loaderFactory             = ImplOptions.Loader;
                    if (loaderFactory == null)
                    {
                        var ext    = Path.GetExtension(dataFile);
                        var isText =
                            string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase) ||
                            string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
                        var isBinary    = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
                        var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);

                        return isText ? TextLoader.Create(Host, new TextLoader.Options(), fileSource) :
                               isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
                               isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
                               defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
                               TextLoader.Create(Host, new TextLoader.Options(), fileSource);
                    }
                    else
                    {
                        loader = loaderFactory.CreateComponent(Host, fileSource);
                    }

                    if (ImplOptions.LoadTransforms == true)
                    {
                        Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                        loader = LoadTransformChain(loader);
                    }
                }
                return loader;
            }
        /// <summary>
        /// Load a <see cref="ICanForecast{T}"/> model.
        /// </summary>
        /// <typeparam name="T">The type of <see cref="ICanForecast{T}"/>, usually float.</typeparam>
        /// <param name="catalog"><see cref="ModelOperationsCatalog"/></param>
        /// <param name="filePath">File path to load the model from.</param>
        /// <returns><see cref="ICanForecast{T}"/> model.</returns>
        public static ICanForecast<T> LoadForecastingModel<T>(this ModelOperationsCatalog catalog, string filePath)
        {
            var env = CatalogUtils.GetEnvironment(catalog);

            using (var file = File.OpenRead(filePath))
            {
                using (var rep = RepositoryReader.Open(file, env))
                {
                    ModelLoadContext.LoadModel<ICanForecast<T>, SignatureLoadModel>(env, out var model, rep, LoaderSignature);
                    return model;
                }
            }
        }
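A hedged usage sketch: reloading a saved forecasting model and asking it for the next few values. The file name is illustrative, and the model is assumed to have been saved as an ICanForecast<float> (for example, an SSA forecaster).

        var mlContext = new MLContext();
        ICanForecast<float> model = mlContext.Model.LoadForecastingModel<float>("forecastModel.zip");
        // Assumes the ICanForecast<float> contract exposes Forecast(horizon).
        float[] next = model.Forecast(5); // forecast the next five values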
Example #9
        public void LoadOriginalBinaryLoaderModel()
        {
            var env = new MLContext().AddStandardComponents();

            using (var modelStream = File.OpenRead(Path.Combine("TestModels", "BinaryLoader-v3.11.0.0.zip")))
                using (var rep = RepositoryReader.Open(modelStream, env))
                {
                    IDataLoader result = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(null), true);

                    Assert.Equal(2, result.Schema.Count);
                    Assert.Equal("Image", result.Schema[0].Name);
                    Assert.Equal("Class", result.Schema[1].Name);
                }
        }
Example #10
        public void TestTextLoaderKeyTypeBackCompat()
        {
            // Model generated with the following command on a version of the code previous to the KeyType change that removed Min and Contiguous:
            // Train data=...\breast-cancer.txt loader =TextLoader{col=Label:R4:0 col=Features:R4:1-9 col=key:U4[0-*]:3} tr=LogisticRegression {} out=model.zip
            var    mlContext           = new MLContext(1);
            string textLoaderModelPath = GetDataPath("backcompat/textloader-with-key-model.zip");
            string breastCancerPath    = GetDataPath(TestDatasets.breastCancer.trainFilename);

            using (FileStream modelfs = File.OpenRead(textLoaderModelPath))
                using (var rep = RepositoryReader.Open(modelfs, mlContext))
                {
                    var result = ModelFileUtils.LoadLoader(mlContext, rep, new MultiFileSource(breastCancerPath), false);
                    Assert.True(result.Schema.TryGetColumnIndex("key", out int featureIdx));
                    Assert.True(result.Schema[featureIdx].Type is KeyDataViewType keyType && keyType.Count == typeof(uint).ToMaxInt());
                }
        }
Example #11
        public void LoadOldConcatTransformModel()
        {
            var env = new MLContext().AddStandardComponents();

            using (var modelStream = File.OpenRead(Path.Combine("TestModels", "ConcatTransform.zip")))
                using (var rep = RepositoryReader.Open(modelStream, env))
                {
                    var result = ModelFileUtils.LoadPipeline(env, rep, new MultiFileSource(null), true);

                    Assert.Equal(3, result.Schema.Count);
                    Assert.Equal("Label", result.Schema[0].Name);
                    Assert.Equal("Features", result.Schema[1].Name);
                    Assert.Equal("Features", result.Schema[2].Name);
                    Assert.Equal(9, (result.Schema[1].Type as VectorType)?.Size);
                    Assert.Equal(18, (result.Schema[2].Type as VectorType)?.Size);
                }
        }
        private IDataScorerTransform GetScorer(IHostEnvironment env, IDataView transforms, IPredictor pred, string testDataPath = null)
        {
            using (var ch = env.Start("Saving model"))
                using (var memoryStream = new MemoryStream())
                {
                    var trainRoles = new RoleMappedData(transforms, label: "Label", feature: "Features");

                    // Model cannot be saved with CacheDataView
                    TrainUtils.SaveModel(env, ch, memoryStream, pred, trainRoles);
                    memoryStream.Position = 0;
                    using (var rep = RepositoryReader.Open(memoryStream, ch))
                    {
                        IDataLoader    testPipe  = ModelFileUtils.LoadLoader(env, rep, new MultiFileSource(testDataPath), true);
                        RoleMappedData testRoles = new RoleMappedData(testPipe, label: "Label", feature: "Features");
                        return ScoreUtils.GetScorer(pred, testRoles, env, testRoles.Schema);
                    }
                }
        }
            protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", string dataFile = null)
            {
                if (string.IsNullOrWhiteSpace(dataFile))
                {
                    dataFile = Args.DataFile;
                }

                IDataLoader loader;

                if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && !Args.Loader.IsGood())
                {
                    // Load the loader from the data model.
                    using (var file = Host.OpenInputFile(Args.InputModelFile))
                    using (var strm = file.OpenReadStream())
                    using (var rep = RepositoryReader.Open(strm, Host))
                        loader = LoadLoader(rep, dataFile, Args.LoadTransforms ?? true);
                }
                else
                {
                    // Either there is no input model file, or there is, but the loader is overridden.
                    var sub = Args.Loader;
                    if (!sub.IsGood())
                    {
                        var ext    = Path.GetExtension(dataFile);
                        var isText =
                            string.Equals(ext, ".txt", StringComparison.OrdinalIgnoreCase) ||
                            string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
                        var isBinary    = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
                        var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
                        sub =
                            new SubComponent<IDataLoader, SignatureDataLoader>(
                                isText ? "TextLoader" : isBinary ? "BinaryLoader" : isTranspose ? "TransposeLoader" : defaultLoader);
                    }

                    loader = sub.CreateInstance(Host, new MultiFileSource(dataFile));

                    if (Args.LoadTransforms == true)
                    {
                        Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile));
                        loader = LoadTransformChain(loader);
                    }
                }
                return loader;
            }
Example #14
        public static bool TryLoadPredictor(IChannel ch, IHostEnvironment env, string inputModelFile, out IPredictor inputPredictor)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(ch);

            if (!string.IsNullOrEmpty(inputModelFile))
            {
                ch.Trace("Constructing predictor from input model");
                using (var file = env.OpenInputFile(inputModelFile))
                using (var strm = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(strm, ch))
                {
                    ch.Trace("Loading predictor");
                    return ModelLoadContext.LoadModelOrNull<IPredictor, SignatureLoadModel>(env, out inputPredictor, rep, ModelFileUtils.DirPredictor);
                }
            }

            inputPredictor = null;
            return false;
        }
        public void TestTextLoaderBackCompat_VerWritt_0x0001000C()
        {
            // Checks backward compatibility with a text loader created with "verWrittenCur: 0x0001000C"
            // Model generated with:
            // loader=text{header+ col=SepalLength:Num:0 col=SepalWidth:Num:1 col=PetalLength:Num:2 col=PetalWidth:Num:2 col=Cat:TX:1-8 col=Num:9-14 col=Type:TX:4}
            var    mlContext           = new MLContext(1);
            string textLoaderModelPath = GetDataPath("backcompat/textloader_VerWritt_0x0001000C.zip");
            string irisPath            = GetDataPath(TestDatasets.irisData.trainFilename);

            IDataView iris;

            using (FileStream modelfs = File.OpenRead(textLoaderModelPath))
                using (var rep = RepositoryReader.Open(modelfs, mlContext))
                {
                    iris = ModelFileUtils.LoadLoader(mlContext, rep, new MultiFileSource(irisPath), false);
                }

            var previewIris  = iris.Preview(1);
            var irisFirstRow = new Dictionary<string, float>();

            irisFirstRow["SepalLength"] = 5.1f;
            irisFirstRow["SepalWidth"]  = 3.5f;
            irisFirstRow["PetalLength"] = 1.4f;
            irisFirstRow["PetalWidth"]  = 0.2f;

            Assert.Equal(5, previewIris.ColumnView.Length);
            Assert.Equal("SepalLength", previewIris.Schema[0].Name);
            Assert.Equal(NumberDataViewType.Single, previewIris.Schema[0].Type);
            int index = 0;

            foreach (var entry in irisFirstRow)
            {
                Assert.Equal(entry.Key, previewIris.RowView[0].Values[index].Key);
                Assert.Equal(entry.Value, previewIris.RowView[0].Values[index++].Value);
            }
            Assert.Equal("Type", previewIris.RowView[0].Values[index].Key);
            Assert.Equal("Iris-setosa", previewIris.RowView[0].Values[index].Value.ToString());
        }
Example #16
        public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx)
        {
            ctx.CheckAtModel(GetVersionInfo());
            ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

            var loaderStream = new MemoryStream();

            using (var rep = RepositoryWriter.CreateNew(loaderStream))
            {
                ModelSaveContext.SaveModel(rep, loader, "Loader");
                rep.Commit();
            }

            _env           = env;
            _loaderFactory = (IMultiStreamSource source) =>
            {
                using (var rep = RepositoryReader.Open(loaderStream))
                {
                    ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var ldr, rep, "Loader", source);
                    return ldr;
                }
                }
            };
        }
Example #17
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckUserArg(!string.IsNullOrWhiteSpace(args.InputModelFile), nameof(args.InputModelFile), "The input model file is required.");

            IPredictor       predictor;
            RoleMappedSchema trainSchema = null;

            using (var file = env.OpenInputFile(args.InputModelFile))
            using (var strm = file.OpenReadStream())
            using (var rep = RepositoryReader.Open(strm, env))
            {
                ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(env, out predictor, rep, ModelFileUtils.DirPredictor);
                trainSchema = ModelFileUtils.LoadRoleMappedSchemaOrNull(env, rep);
            }

            string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                                                              nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                                                               nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

            return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
        }
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register("LoadTransform");

            h.CheckValue(args, nameof(args));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile), "File does not exist");

            IDataView currentView;

            // If there are no 'tag' parameters, we load everything, regardless of 'comp'.
            bool complement = args.Complement || Utils.Size(args.Tag) == 0;
            var  allTags    = new HashSet<string>();

            for (int i = 0; i < Utils.Size(args.Tag); i++)
            {
                var curList = args.Tag[i];
                if (string.IsNullOrWhiteSpace(curList))
                {
                    continue;
                }

                foreach (var tag in curList.Split(','))
                {
                    if (!string.IsNullOrWhiteSpace(tag))
                    {
                        allTags.Add(tag.ToLower());
                    }
                }
            }

            Func<string, bool> predicate =
                tag =>
            {
                bool found = allTags.Contains(tag.ToLower());
                return found == !complement;
            };

            using (var file = h.OpenInputFile(args.ModelFile))
            using (var strm = file.OpenReadStream())
            using (var rep = RepositoryReader.Open(strm, h))
            using (var pipeLoaderEntry = rep.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
            using (var ctx = new ModelLoadContext(rep, pipeLoaderEntry, ModelFileUtils.DirDataLoaderModel))
            {
                currentView = CompositeDataLoader.LoadSelectedTransforms(ctx, input, h, predicate);

                if (currentView == input)
                {
                    // REVIEW: we are required to return an IDataTransform. Therefore, if we don't introduce a new transform
                    // on top of 'input', we must throw (since input may not be a data transform).
                    // We could of course introduce a 'no-op transform', or we could lift the requirement to always return
                    // an IDataTransform associated with SignatureDataTransform.

                    var criteria = string.Format(
                        complement
                            ? "transforms that don't have tags from the list: '{0}'"
                            : "transforms that have tags from the list: '{0}'",
                        string.Join(",", allTags));
                    throw h.ExceptUserArg(nameof(args.Tag), "No transforms were found that match the search criteria ({0})", criteria);
                }
            }

            h.Assert(currentView is IDataTransform);
            return (IDataTransform)currentView;
        }
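A standalone sketch, not from the source, of the selection rule above: a transform is kept exactly when its tag's membership in the requested set matches the complement flag.

        var allTags = new HashSet<string> { "a", "b" };
        bool complement = false;
        Func<string, bool> predicate = tag => allTags.Contains(tag.ToLower()) == !complement;

        Console.WriteLine(predicate("a")); // True  -> a transform tagged "a" is loaded
        Console.WriteLine(predicate("c")); // False -> a transform tagged "c" is skipped
        // With complement == true the selection inverts; with no tags at all the code
        // above forces complement to true, so every transform is loaded.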
        // Factory method for SignatureDataTransform.
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            IDataTransform xf;

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };
                if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    if (args.Trainer != null)
                    {
                        ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                    }

                    ch.Trace("Loading model");
                    IPredictor predictor;
                    using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                    using (var rep = RepositoryReader.Open(strm, ch))
                        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);

                    ch.Trace("Creating scorer");
                    var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                    Contracts.Assert(data.Schema.Feature.HasValue);

                    // Make sure that the given predictor has the correct number of input features.
                    if (predictor is CalibratedPredictorBase)
                    {
                        predictor = ((CalibratedPredictorBase)predictor).SubPredictor;
                    }
                    // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
                    // be non-null.
                    var vm = predictor as IValueMapper;
                    ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
                    if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize)
                    {
                        throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                               "Predictor in model file expects {0} features, but data has {1} features",
                                               vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize);
                    }

                    ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                    var bound = bindable.Bind(env, data.Schema);
                    xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
                }
                else
                {
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                        (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf    = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    if (input == labelInput)
                    {
                        return scoreXf;
                    }
                    return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                }
            }
            return xf;
        }
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateInstance(Host);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            ISchema schema = trainPipe.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var data       = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(Args.ValidationFile))
            {
                if (!TrainUtils.CanUseValidationData(trainer))
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, _info.LoadNames[0], validData,
                                             Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            IDataLoader testPipe;

            using (var file = !string.IsNullOrEmpty(Args.OutputModelFile) ?
                              Host.CreateOutputFile(Args.OutputModelFile) : Host.CreateTempFile(".zip"))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);

                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(stream, ch))
                    testPipe = LoadLoader(rep, Args.TestFile, true);
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

            // Evaluate.
            var evalComp = Args.Evaluator;

            if (!evalComp.IsGood())
            {
                evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }
            var evaluator = evalComp.CreateInstance(Host);
            var dataEval  = new RoleMappedData(scorePipe, label, features,
                                               group, weight, name, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
Example #21
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = ImplOptions.Trainer.CreateComponent(Host);

            IPredictor inputPredictor = null;

            if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            var    schema = trainPipe.Schema;
            string label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
                                                                ImplOptions.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
                                                                  ImplOptions.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.GroupColumn),
                                                               ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
                                                                ImplOptions.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.NameColumn),
                                                              ImplOptions.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, features, ImplOptions.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, ImplOptions.CustomColumns);
            var data       = new RoleMappedData(trainPipe, label, features, group, weight, name, customCols);

            RoleMappedData validData = null;

            if (!string.IsNullOrWhiteSpace(ImplOptions.ValidationFile))
            {
                if (!trainer.Info.SupportsValidation)
                {
                    ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
                }
                else
                {
                    ch.Trace("Constructing the validation pipeline");
                    IDataView validPipe = CreateRawLoader(dataFile: ImplOptions.ValidationFile);
                    validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                    validData = new RoleMappedData(validPipe, data.Schema.GetColumnRoleNames());
                }
            }

            // In addition to the training set, some trainers can accept two more data sets, a validation set and a test
            // set, during the training phase. The major difference between them is that the training process may
            // indirectly use the validation set to improve the model, whereas the learned model should be totally
            // independent of the test set. As with the validation set, the trainer can report scores computed on the test set.
            RoleMappedData testDataUsedInTrainer = null;

            if (!string.IsNullOrWhiteSpace(ImplOptions.TestFile))
            {
                // In contrast to the if-else block for validation above, we do not throw a warning if test file is provided
                // because this is TrainTest command.
                if (trainer.Info.SupportsTest)
                {
                    ch.Trace("Constructing the test pipeline");
                    IDataView testPipeUsedInTrainer = CreateRawLoader(dataFile: ImplOptions.TestFile);
                    testPipeUsedInTrainer = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, testPipeUsedInTrainer);
                    testDataUsedInTrainer = new RoleMappedData(testPipeUsedInTrainer, data.Schema.GetColumnRoleNames());
                }
            }

            var predictor = TrainUtils.Train(Host, ch, data, trainer, validData,
                                             ImplOptions.Calibrator, ImplOptions.MaxCalibrationExamples, ImplOptions.CacheData, inputPredictor, testDataUsedInTrainer);

            ILegacyDataLoader testPipe;
            bool hasOutfile   = !string.IsNullOrEmpty(ImplOptions.OutputModelFile);
            var  tempFilePath = hasOutfile ? null : Path.GetTempFileName();

            using (var file = new SimpleFileHandle(ch, hasOutfile ? ImplOptions.OutputModelFile : tempFilePath, true, !hasOutfile))
            {
                TrainUtils.SaveModel(Host, ch, file, predictor, data, cmd);
                ch.Trace("Constructing the testing pipeline");
                using (var stream = file.OpenReadStream())
                using (var rep = RepositoryReader.Open(stream, ch))
                    testPipe = LoadLoader(rep, ImplOptions.TestFile, true);
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(ImplOptions.Scorer == null || ImplOptions.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(ImplOptions.Scorer, predictor, testPipe, features, group, customCols, Host, data.Schema);

            // Evaluate.
            var evaluator = ImplOptions.Evaluator?.CreateComponent(Host) ??
                            EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            var dataEval = new RoleMappedData(scorePipe, label, features,
                                              group, weight, name, customCols, opt: true);
            var metrics = evaluator.Evaluate(dataEval);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(dataEval);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, idv);
            }
        }
            /// <summary>
            /// Loads multiple artifacts of interest from the input model file, given the context
            /// established by the command line arguments.
            /// </summary>
            /// <param name="ch">The channel to which to provide output.</param>
            /// <param name="wantPredictor">Whether we want a predictor from the model file. If
            /// <c>false</c> we will not even attempt to load a predictor. If <c>null</c> we will
            /// load the predictor, if present. If <c>true</c> we will load the predictor, or fail
            /// noisily if we cannot.</param>
            /// <param name="predictor">The predictor in the model, or <c>null</c> if
            /// <paramref name="wantPredictor"/> was false, or <paramref name="wantPredictor"/> was
            /// <c>null</c> and no predictor was present.</param>
            /// <param name="wantTrainSchema">Whether we want the training schema. Unlike
            /// <paramref name="wantPredictor"/>, this has no "hard fail if not present" option. If
            /// this is <c>true</c>, it is still possible for <paramref name="trainSchema"/> to remain
            /// <c>null</c> if there were no role mappings or pipeline.</param>
            /// <param name="trainSchema">The training schema if <paramref name="wantTrainSchema"/>
            /// is true, and there were role mappings stored in the model.</param>
            /// <param name="pipe">The data pipe constructed from the combination of the
            /// model and command line arguments.</param>
            protected void LoadModelObjects(
                IChannel ch,
                bool? wantPredictor, out IPredictor predictor,
                bool wantTrainSchema, out RoleMappedSchema trainSchema,
                out ILegacyDataLoader pipe)
            {
                // First handle the case where there is no input model file.
                // Everything must come from the command line.

                using (var file = Host.OpenInputFile(ImplOptions.InputModelFile))
                    using (var strm = file.OpenReadStream())
                        using (var rep = RepositoryReader.Open(strm, Host))
                        {
                            // First consider loading the predictor.
                            if (wantPredictor == false)
                            {
                                predictor = null;
                            }
                            else
                            {
                                ch.Trace("Loading predictor");
                                predictor = ModelFileUtils.LoadPredictorOrNull(Host, rep);
                                if (wantPredictor == true)
                                {
                                    Host.Check(predictor != null, "Could not load predictor from model file");
                                }
                            }

                            // Next create the loader.
                            var loaderFactory           = ImplOptions.Loader;
                            ILegacyDataLoader trainPipe = null;
                            if (loaderFactory != null)
                            {
                                // The loader is overridden from the command line.
                                pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
                                if (ImplOptions.LoadTransforms == true)
                                {
                                    Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                                    pipe = LoadTransformChain(pipe);
                                }
                            }
                            else
                            {
                                var loadTrans = ImplOptions.LoadTransforms ?? true;
                                pipe = LoadLoader(rep, ImplOptions.DataFile, loadTrans);
                                if (loadTrans)
                                {
                                    trainPipe = pipe;
                                }
                            }

                            if (Utils.Size(ImplOptions.Transforms) > 0)
                            {
                                pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);
                            }

                            // Next consider loading the training data's role mapped schema.
                            trainSchema = null;
                            if (wantTrainSchema)
                            {
                                // First try to get the role mappings.
                                var trainRoleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, rep);
                                if (trainRoleMappings != null)
                                {
                                    // Next create the training schema. In the event that the loaded pipeline happens
                                    // to be the training pipe, we can just use that. If it differs, then we need to
                                    // load the full pipeline from the model, relying upon the fact that all loaders
                                    // can be loaded with no data at all, to get their schemas.
                                    if (trainPipe == null)
                                    {
                                        trainPipe = ModelFileUtils.LoadLoader(Host, rep, new MultiFileSource(null), loadTransforms: true);
                                    }
                                    trainSchema = new RoleMappedSchema(trainPipe.Schema, trainRoleMappings);
                                }
                                // If the role mappings are null, an alternative would be to fail. However the idea
                                // is that the scorer should always still succeed, although perhaps with reduced
                                // functionality, even when the training schema is null, since not all versions of
                                // TLC models will have the role mappings preserved, I believe. And, we do want to
                                // maintain backwards compatibility.
                            }
                        }
            }
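A minimal sketch of the tri-state contract described in the doc comment above; the call sites are illustrative.

                // wantPredictor == true: load the predictor or fail noisily.
                LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var pipe);
                // wantPredictor == null: load the predictor only if one is present (may be null).
                LoadModelObjects(ch, null, out predictor, false, out trainSchema, out pipe);
                // wantPredictor == false: do not even attempt to load a predictor.
                LoadModelObjects(ch, false, out predictor, false, out trainSchema, out pipe);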