예제 #1
0
        /// <summary>
        /// Builds the output schema from the partitioned columns and the sub loader's own columns.
        /// </summary>
        /// <param name="ectx">The exception context (not referenced by the body of this method).</param>
        /// <param name="cols">The partitioned columns; each column's <c>Name</c> and <c>Type.Value</c> are read.</param>
        /// <param name="subLoader">The sub loader whose schema is combined with the partitioned columns.</param>
        /// <returns>
        /// The partitioned-column schema alone when the sub loader's schema has no columns; otherwise the
        /// <c>OutputSchema</c> of a <see cref="ZipBinding"/> built over the sub loader schema followed by
        /// the partitioned-column schema.
        /// </returns>
        private DataViewSchema CreateSchema(IExceptionContext ectx, Column[] cols, ILegacyDataLoader subLoader)
        {
            Contracts.AssertValue(cols);
            Contracts.AssertValue(subLoader);

            // One detached primitive-typed column (no annotations) per partitioned column.
            var partitionBuilder = new DataViewSchema.Builder();
            var detachedColumns = cols.Select(col =>
                new DataViewSchema.DetachedColumn(col.Name, ColumnTypeExtensions.PrimitiveTypeFromKind(col.Type.Value), null));
            partitionBuilder.AddColumns(detachedColumns);
            DataViewSchema partitionSchema = partitionBuilder.ToSchema();

            DataViewSchema loaderSchema = subLoader.Schema;
            if (loaderSchema.Count != 0)
            {
                // The sub loader schema is listed first, then the partitioned columns.
                return new ZipBinding(new[] { loaderSchema, partitionSchema }).OutputSchema;
            }

            return partitionSchema;
        }
예제 #2
0
        /// <summary>
        /// Selects a data saver, creates the loader, and saves the loader's data to the configured output file.
        /// </summary>
        /// <param name="ch">The channel handed to the save routine.</param>
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch, "ch");

            IDataSaver saver;
            if (ImplOptions.Saver != null)
            {
                saver = ImplOptions.Saver.CreateComponent(Host);
            }
            else
            {
                // No saver was specified: pick one from the output file's extension.
                string extension = Path.GetExtension(ImplOptions.OutputDataFile);
                if (string.Equals(extension, ".idv", StringComparison.OrdinalIgnoreCase))
                {
                    saver = new BinarySaver(Host, new BinarySaver.Arguments());
                }
                else if (string.Equals(extension, ".tdv", StringComparison.OrdinalIgnoreCase))
                {
                    saver = new TransposeSaver(Host, new TransposeSaver.Arguments());
                }
                else
                {
                    // Any other extension falls back to the text saver.
                    saver = new TextSaver(Host, new TextSaver.Arguments());
                }
            }

            ILegacyDataLoader loader = CreateAndSaveLoader();

            using (var outputFile = Host.CreateOutputFile(ImplOptions.OutputDataFile))
            {
                DataSaverUtils.SaveDataView(ch, saver, loader, outputFile, ImplOptions.KeepHidden);
            }
        }
        /// <summary>
        /// Initializes a composite loader from a non-empty array of transforms. The exposed view is the
        /// last transform in the array, and the underlying loader is the source of the first one.
        /// </summary>
        /// <param name="host">The host used for assertions; stored on the instance.</param>
        /// <param name="transforms">The transform chain; asserted to be non-empty.</param>
        private LegacyCompositeDataLoader(IHost host, TransformEx[] transforms)
        {
            Contracts.AssertValue(host, "host");
            _host = host;
            _host.AssertNonEmpty(transforms);

            int lastIndex = transforms.Length - 1;
            View = transforms[lastIndex].Transform;
            _tview = View as ITransposeDataView;

            var chainSource = transforms[0].Transform.Source as ILegacyDataLoader;

#if DEBUG
            // Debug-only consistency checks: every transform must be fed by its predecessor,
            // the chain must begin at a loader, and that loader must not itself be a composite.
            for (int pos = 1; pos < transforms.Length; pos++)
            {
                _host.Assert(transforms[pos].Transform.Source == transforms[pos - 1].Transform, "Transforms are not linked");
            }

            _host.AssertValue(chainSource, "loader", "Transform chain doesn't start with a loader");
            _host.Assert(!(chainSource is LegacyCompositeDataLoader), "Can't have composite source loader");
#endif

            _transforms = transforms;
            _loader = chainSource;
        }
        /// <summary>
        /// Reads the transform entries stored in <paramref name="ctx"/>, keeps those whose tag passes
        /// <paramref name="isTransformTagAccepted"/>, and applies the kept transforms in order on top of
        /// <paramref name="srcLoader"/>. When no transform is kept, <paramref name="srcLoader"/> is returned.
        /// </summary>
        private static ILegacyDataLoader LoadTransforms(ModelLoadContext ctx, ILegacyDataLoader srcLoader, IHost host, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader);
            host.AssertValue(ctx);

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            int floatSize = ctx.Reader.ReadInt32();
            host.CheckDecode(floatSize == sizeof(float));

            int transformCount = ctx.Reader.ReadInt32();
            host.CheckDecode(transformCount >= 0);

            // Models older than VersionAddedTags carry no per-transform tag/args strings.
            bool tagsStored = ctx.Header.ModelVerReadable >= VersionAddedTags;
            var acceptedTags = new List<KeyValuePair<string, string>>();
            var acceptedPositions = new List<int>();

            for (int position = 0; position < transformCount; position++)
            {
                string tag = "";
                string argsString = null;
                if (tagsStored)
                {
                    tag = ctx.LoadNonEmptyString();
                    argsString = ctx.LoadStringOrNull();
                }

                if (isTransformTagAccepted(tag))
                {
                    // Remember the original position: it is used to build the model directory name below.
                    acceptedPositions.Add(position);
                    acceptedTags.Add(new KeyValuePair<string, string>(tag, argsString));
                }
            }

            host.Assert(acceptedTags.Count == acceptedPositions.Count);
            if (acceptedTags.Count == 0)
            {
                return srcLoader;
            }

            return ApplyTransformsCore(
                host,
                srcLoader,
                acceptedTags.ToArray(),
                (h, index, data) =>
                {
                    string modelDir = string.Format(TransformDirTemplate, acceptedPositions[index]);
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out IDataTransform loaded, modelDir, data);
                    return loaded;
                });
        }
 /// <summary>
 /// Initializes a fold helper: validates the arguments and copies them, together with the relevant
 /// settings read from <paramref name="args"/>, into the instance fields.
 /// </summary>
 /// <param name="env">The environment.</param>
 /// <param name="registrationName">The registration name.</param>
 /// <param name="inputDataView">The input data view.</param>
 /// <param name="splitColumn">The column to use for splitting data into folds.</param>
 /// <param name="args">Cross validation arguments. Must not be null; <c>NumFolds</c> must be greater than 1,
 /// <c>Trainer</c> must be set, and <c>MaxCalibrationExamples</c> must be positive.</param>
 /// <param name="createExamples">The delegate to create RoleMappedData</param>
 /// <param name="applyTransformsToTestData">The delegate to apply the transforms from the train pipeline to the test data</param>
 /// <param name="scorer">The scorer</param>
 /// <param name="evaluator">The evaluator</param>
 /// <param name="getValidationDataView">The delegate to create validation data view</param>
 /// <param name="applyTransformsToValidationData">The delegate to apply the transforms from the train pipeline to the validation data.
 /// Required whenever <paramref name="getValidationDataView"/> is supplied.</param>
 /// <param name="inputPredictor">The input predictor, for the continue training option</param>
 /// <param name="cmd">The command string.</param>
 /// <param name="loader">Original loader so we can construct correct pipeline for model saving.</param>
 /// <param name="savePerInstance">Whether to produce the per-instance data view.</param>
 public FoldHelper(
     IHostEnvironment env,
     string registrationName,
     IDataView inputDataView,
     string splitColumn,
     Arguments args,
     Func<IHostEnvironment, IChannel, IDataView, ITrainer, RoleMappedData> createExamples,
     Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToTestData,
     IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> scorer,
     IComponentFactory<IMamlEvaluator> evaluator,
     Func<IDataView> getValidationDataView = null,
     Func<IHostEnvironment, IChannel, IDataView, RoleMappedData, IDataView, RoleMappedData> applyTransformsToValidationData = null,
     IPredictor inputPredictor = null,
     string cmd = null,
     ILegacyDataLoader loader = null,
     bool savePerInstance = false)
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckNonWhiteSpace(registrationName, nameof(registrationName));
     env.CheckValue(inputDataView, nameof(inputDataView));
     env.CheckValue(splitColumn, nameof(splitColumn));
     // Validate args before reading any of its members, so a null argument is reported
     // as an argument error rather than surfacing as a NullReferenceException below.
     env.CheckValue(args, nameof(args));
     env.CheckParam(args.NumFolds > 1, nameof(args.NumFolds));
     env.CheckValue(createExamples, nameof(createExamples));
     env.CheckValue(applyTransformsToTestData, nameof(applyTransformsToTestData));
     env.CheckValue(args.Trainer, nameof(args.Trainer));
     env.CheckValueOrNull(scorer);
     env.CheckValueOrNull(evaluator);
     env.CheckValueOrNull(args.Calibrator);
     env.CheckParam(args.MaxCalibrationExamples > 0, nameof(args.MaxCalibrationExamples));
     // A validation data source is only usable together with a way to transform it.
     env.CheckParam(getValidationDataView == null || applyTransformsToValidationData != null, nameof(applyTransformsToValidationData));
     env.CheckValueOrNull(inputPredictor);
     env.CheckValueOrNull(cmd);
     env.CheckValueOrNull(args.OutputModelFile);
     env.CheckValueOrNull(loader);

     _env = env;
     _registrationName = registrationName;
     _inputDataView = inputDataView;
     _splitColumn = splitColumn;
     _numFolds = args.NumFolds;
     _createExamples = createExamples;
     _applyTransformsToTestData = applyTransformsToTestData;
     _trainer = args.Trainer;
     _scorer = scorer;
     _evaluator = evaluator;
     _calibrator = args.Calibrator;
     _maxCalibrationExamples = args.MaxCalibrationExamples;
     _useThreads = args.UseThreads;
     _cacheData = args.CacheData;
     _getValidationDataView = getValidationDataView;
     _applyTransformsToValidationData = applyTransformsToValidationData;
     _inputPredictor = inputPredictor;
     _cmd = cmd;
     _outputModelFile = args.OutputModelFile;
     _loader = loader;
     _savePerInstance = savePerInstance;
 }
예제 #6
0
        /// <summary>
        /// Saves <paramref name="loader"/> into an in-memory stream and returns the bytes that were written.
        /// </summary>
        /// <param name="loader">The loader to serialize; must not be null.</param>
        /// <returns>A copy of exactly the bytes written to the stream.</returns>
        private byte[] SaveLoaderToBytes(ILegacyDataLoader loader)
        {
            Contracts.CheckValue(loader, nameof(loader));

            using (var stream = new MemoryStream())
            {
                LoaderUtils.SaveLoader(loader, stream);

                // ToArray() copies only the written portion of the stream. GetBuffer() would hand back
                // the whole backing array, whose tail is unused capacity that is not part of the model.
                return stream.ToArray();
            }
        }
        /// <summary>
        /// Builds a loader that begins with <paramref name="srcLoader"/> and continues with the transforms
        /// produced by the factories in <paramref name="transformArgs"/>.
        /// When there are no transforms, <paramref name="srcLoader"/> itself is returned.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="srcLoader">The source loader; must not be null.</param>
        /// <param name="transformArgs">The (tag, transform factory) pairs; may be null or empty.</param>
        public static ILegacyDataLoader Create(IHostEnvironment env, ILegacyDataLoader srcLoader,
                                               params KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] transformArgs)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(srcLoader, nameof(srcLoader));
            host.CheckValueOrNull(transformArgs);

            return CreateCore(host, srcLoader, transformArgs);
        }
        private static ILegacyDataLoader CreateCore(IHost host, ILegacyDataLoader srcLoader,
                                                    KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >[] transformArgs)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertValueOrNull(transformArgs);

            if (Utils.Size(transformArgs) == 0)
            {
                return(srcLoader);
            }
            /// <summary>
            /// Opens the input model file, reads its data loader model entry, and creates a composite loader
            /// over <paramref name="srcData"/> from it, accepting every transform tag.
            /// </summary>
            /// <param name="srcData">The source loader passed to the composite loader's <c>Create</c>.</param>
            private ILegacyDataLoader LoadTransformChain(ILegacyDataLoader srcData)
            {
                Host.Assert(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile));

                using (var modelFile = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var modelStream = modelFile.OpenReadStream())
                using (var repository = RepositoryReader.Open(modelStream, Host))
                using (var loaderEntry = repository.OpenEntry(ModelFileUtils.DirDataLoaderModel, ModelLoadContext.ModelStreamName))
                using (var loadContext = new ModelLoadContext(repository, loaderEntry, ModelFileUtils.DirDataLoaderModel))
                {
                    // The tag filter accepts everything, so all stored transforms are loaded.
                    return LegacyCompositeDataLoader.Create(Host, loadContext, srcData, tag => true);
                }
            }
        /// <summary>
        /// Writes <paramref name="loader"/> to <paramref name="file"/> by opening a write stream on the
        /// file and delegating to the stream-based overload.
        /// </summary>
        /// <param name="loader">The loader to save; must not be null.</param>
        /// <param name="file">The destination file handle; must not be null and must be writable.</param>
        public static void SaveLoader(ILegacyDataLoader loader, IFileHandle file)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(file, nameof(file));
            Contracts.CheckParam(file.CanWrite, nameof(file), "Must be writable");

            using (var output = file.CreateWriteStream())
            {
                SaveLoader(loader, output);
            }
        }
예제 #11
0
        /// <summary>
        /// Creates the loader, renders the output of <c>RunOnData</c> into a string, and reports that
        /// string on the channel as an info message.
        /// </summary>
        /// <param name="ch">The channel that receives the rendered text.</param>
        private void RunCore(IChannel ch)
        {
            ILegacyDataLoader loader = CreateAndSaveLoader();

            using (var writer = new StringWriter())
            {
                RunOnData(writer, ImplOptions, loader);

                string rendered = writer.ToString();
                ch.AssertNonEmpty(rendered);
                ch.Info(rendered);
            }
        }
        /// <summary>
        /// Writes <paramref name="loader"/> to <paramref name="stream"/> as a new repository containing
        /// the data loader model, then commits the repository.
        /// </summary>
        /// <param name="loader">The loader to save; must not be null.</param>
        /// <param name="stream">The destination stream; must not be null and must be writable.</param>
        public static void SaveLoader(ILegacyDataLoader loader, Stream stream)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(stream, nameof(stream));
            Contracts.CheckParam(stream.CanWrite, nameof(stream), "Must be writable");

            using (var repository = RepositoryWriter.CreateNew(stream))
            {
                ModelSaveContext.SaveModel(repository, loader, ModelFileUtils.DirDataLoaderModel);
                repository.Commit();
            }
        }
        /// <summary>
        /// Appends transforms to <paramref name="srcLoader"/> and returns the resulting loader.
        /// With nothing to append, <paramref name="srcLoader"/> is returned unchanged; otherwise the work is
        /// delegated to <c>ApplyTransformsCore</c>, which invokes <paramref name="createTransform"/> once
        /// per element of <paramref name="tagData"/>, in order.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="srcLoader">The source loader.</param>
        /// <param name="tagData">The array of (tag, creationInfo) pairs. May be null or empty, in which case
        /// <paramref name="srcLoader"/> is returned.</param>
        /// <param name="createTransform">The delegate invoked for each transform creation.
        /// Its parameters are: host environment, transform index (0 to <c>tagData.Length - 1</c>), source data view.
        /// It should return an <see cref="IDataView"/> that shares the same loader as the source data view.</param>
        /// <returns>The resulting data loader.</returns>
        public static ILegacyDataLoader ApplyTransforms(IHostEnvironment env, ILegacyDataLoader srcLoader,
            KeyValuePair<string, string>[] tagData, Func<IHostEnvironment, int, IDataView, IDataView> createTransform)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(srcLoader, nameof(srcLoader));
            host.CheckValueOrNull(tagData);
            host.CheckValue(createTransform, nameof(createTransform));

            return Utils.Size(tagData) == 0
                ? srcLoader
                : ApplyTransformsCore(host, srcLoader, tagData, createTransform);
        }
        /// <summary>
        /// Creates an <see cref="ILegacyDataLoader"/> from the specified source loader, followed by the
        /// transforms loaded from <paramref name="ctx"/> whose tags are accepted by
        /// <paramref name="isTransformTagAccepted"/>.
        /// If <paramref name="ctx"/> contains no accepted transforms, <paramref name="srcLoader"/> is
        /// returned intact.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="ctx">The model load context; must not be null and is checked against this loader's version info.</param>
        /// <param name="srcLoader">The source loader; must not be null.</param>
        /// <param name="isTransformTagAccepted">The tag filter; must not be null.</param>
        public static ILegacyDataLoader Create(IHostEnvironment env, ModelLoadContext ctx,
                                               ILegacyDataLoader srcLoader, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.CheckValue(env, nameof(env));

            var host = env.Register(RegistrationName);
            host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            host.CheckValue(srcLoader, nameof(srcLoader));
            host.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted));

            return LoadTransforms(ctx, srcLoader, host, isTransformTagAccepted);
        }
        /// <summary>
        /// Applies a single transform on top of <paramref name="srcLoader"/> and returns the resulting
        /// (composite) data loader. The transform is produced by calling <paramref name="createTransform"/>
        /// with a data source; it should return an <see cref="IDataView"/> that shares the same loader as
        /// that source.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="srcLoader">The source loader; must not be null.</param>
        /// <param name="tag">The tag for the transform; may be null.</param>
        /// <param name="creationArgs">The creation arguments recorded with the transform; may be null.</param>
        /// <param name="createTransform">The delegate that creates the transform; must not be null.</param>
        public static ILegacyDataLoader ApplyTransform(IHostEnvironment env, ILegacyDataLoader srcLoader,
            string tag, string creationArgs, Func<IHostEnvironment, IDataView, IDataView> createTransform)
        {
            Contracts.CheckValue(env, nameof(env));

            var checkHost = env.Register(RegistrationName);
            checkHost.CheckValue(srcLoader, nameof(srcLoader));
            checkHost.CheckValueOrNull(tag);
            checkHost.CheckValueOrNull(creationArgs);
            checkHost.CheckValue(createTransform, nameof(createTransform));

            var singleTag = new[] { new KeyValuePair<string, string>(tag, creationArgs) };

            // NOTE(review): a second host is registered for the core call instead of reusing the one
            // above. Kept as is in case registration has side effects - confirm before merging the two.
            var coreHost = env.Register(RegistrationName);

            // Adapt the two-argument delegate to the indexed form; the index is not needed for one transform.
            return ApplyTransformsCore(coreHost, srcLoader, singleTag, (e, index, data) => createTransform(e, data));
        }
예제 #16
0
            /// <summary>
            /// Advances to the next row. Rows come from the current sub cursor; when it is missing or
            /// exhausted, the next path is fetched and a new sub cursor is opened over it. Paths whose
            /// loader cannot be created, or whose getters/column values cannot be set up, are skipped
            /// with a warning.
            /// </summary>
            /// <returns><c>true</c> once a sub cursor has moved to a row; <c>false</c> when no further path is available.</returns>
            protected override bool MoveNextCore()
            {
                // Loop until the active sub cursor yields a row, switching files as needed.
                while (_subCursor == null || !_subCursor.MoveNext())
                {
                    // Release the previous sub cursor, if any.
                    _subCursor?.Dispose();
                    _subCursor = null;

                    bool hasNext = TryGetNextPathAndValues(out string path, out string relativePath, out List<string> values);
                    if (!hasNext)
                    {
                        return false;
                    }

                    ILegacyDataLoader fileLoader = null;
                    try
                    {
                        // Rebuild the sub loader from its saved bytes, pointed at the new file.
                        fileLoader = _parent.CreateLoaderFromBytes(_parent._subLoaderBytes, new MultiFileSource(path));
                    }
                    catch (Exception loadError)
                    {
                        Ch.Warning($"Failed to load file {path} due to a loader exception. Moving on to the next file. Ex: {loadError.Message}");
                        continue;
                    }

                    _subCursor = fileLoader.GetRowCursor(_subActivecolumnsNeeded);

                    try
                    {
                        UpdateSubGetters();
                        UpdateColumnValues(relativePath, values);
                    }
                    catch (InvalidOperationException setupError)
                    {
                        // This file could not be set up: warn, drop its cursor, and let the loop try the next one.
                        Ch.Warning(MessageSensitivity.Schema, setupError.Message);
                        _subCursor?.Dispose();
                        _subCursor = null;
                    }
                }

                return true;
            }
예제 #17
0
        /// <summary>
        /// Loads the loader stored in the "BinaryLoader-v3.11.0.0.zip" test model and checks that its
        /// schema has exactly two columns, named "Image" and "Class" in that order.
        /// </summary>
        public void LoadBinaryLoaderModelVersion3()
        {
            var env = new MLContext(1).AddStandardComponents();
            string modelPath = Path.Combine("TestModels", "BinaryLoader-v3.11.0.0.zip");

            using (var modelStream = File.OpenRead(modelPath))
            using (var repository = RepositoryReader.Open(modelStream, env))
            {
                ILegacyDataLoader loaded = ModelFileUtils.LoadLoader(env, repository, new MultiFileSource(null), true);

                Assert.Equal(2, loaded.Schema.Count);
                Assert.Equal("Image", loaded.Schema[0].Name);
                Assert.Equal("Class", loaded.Schema[1].Name);
            }
        }
        /// <summary>
        /// Core of <c>Create</c>: turns the (tag, factory) pairs into (tag, args string) pairs, warns about
        /// tags that already exist when <paramref name="srcLoader"/> is a composite loader, and applies
        /// the transforms via <c>ApplyTransformsCore</c>.
        /// Returns <paramref name="srcLoader"/> when <paramref name="transformArgs"/> is null or empty.
        /// </summary>
        private static ILegacyDataLoader CreateCore(IHost host, ILegacyDataLoader srcLoader,
                                                    KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] transformArgs)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertValueOrNull(transformArgs);

            if (Utils.Size(transformArgs) == 0)
            {
                return srcLoader;
            }

            var tagData = new KeyValuePair<string, string>[transformArgs.Length];
            for (int pos = 0; pos < transformArgs.Length; pos++)
            {
                // When coming from the command line, preserve the string arguments.
                // For other factories, we aren't able to get the string, so null is recorded.
                string argsString = (transformArgs[pos].Value as ICommandLineComponentFactory)?.ToString();
                tagData[pos] = new KeyValuePair<string, string>(transformArgs[pos].Key, argsString);
            }

            // Warn if tags coincide with ones already present in the loader.
            if (srcLoader is LegacyCompositeDataLoader existing)
            {
                using (var ch = host.Start("TagValidation"))
                {
                    foreach (var entry in tagData)
                    {
                        if (string.IsNullOrEmpty(entry.Key))
                        {
                            continue;
                        }

                        if (existing._transforms.Any(x => x.Tag == entry.Key))
                        {
                            ch.Warning("The transform with tag '{0}' already exists in the chain", entry.Key);
                        }
                    }
                }
            }

            return ApplyTransformsCore(
                host,
                srcLoader,
                tagData,
                (env, index, data) => transformArgs[index].Value.CreateComponent(env, data));
        }
        /// <summary>
        /// Runs cross validation: builds the data pipeline, creates the per-fold tasks through a
        /// <c>FoldHelper</c>, prints per-fold and overall metrics, and, when an output data file is
        /// configured, saves the per-instance results.
        /// </summary>
        /// <param name="ch">The channel used for tracing, warnings and metric output.</param>
        /// <param name="cmd">The command string, forwarded to the <c>FoldHelper</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            // Only attempt to load an input predictor when continuing training; warn if that fails.
            if (ImplOptions.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, ImplOptions.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            ILegacyDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = ImplOptions.PreTransforms;

            if (!string.IsNullOrEmpty(ImplOptions.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(ImplOptions.NameColumn), ImplOptions.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    // Append (with an empty tag) a factory that generates a counter-based name column.
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >(
                            "", ComponentFactoryUtils.CreateFromFunction <IDataView, IDataTransform>(
                                (env, input) =>
                        {
                            var args     = new GenerateNumberTransform.Options();
                            args.Columns = new[] { new GenerateNumberTransform.Column()
                                                   {
                                                       Name = DefaultColumnNames.Name
                                                   }, };
                            args.UseCounter = true;
                            return(new GenerateNumberTransform(env, args, input));
                        }))
                    }).ToArray();
                }
            }
            // Wrap the raw loader with the pre-transforms.
            loader = LegacyCompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            // GetSplitColumn receives pipe by ref, so it may replace the view used for the folds.
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = ImplOptions.Scorer;
            var       evaluator = ImplOptions.Evaluator;

            Func <IDataView> validDataCreator = null;

            // A validation data factory is only supplied when a validation file is configured.
            if (ImplOptions.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: ImplOptions.ValidationFile));
                };
            }

            // The last argument (savePerInstance) is true exactly when an output data file is configured.
            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             ImplOptions, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(ImplOptions.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // Use the configured evaluator if there is one; otherwise pick one from the first fold's score schema.
            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Print the overall results.
            if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, ImplOptions.SummaryFilename, overall, ImplOptions.NumFolds);
            eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
            SendTelemetryMetric(metricValues);

            // Save the per-instance results.
            // NOTE(review): this check uses IsNullOrWhiteSpace, while the earlier checks on OutputDataFile
            // in this method use IsNullOrEmpty - confirm the difference is intended.
            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputDataFile))
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, ImplOptions.CollateMetrics,
                                                                                ImplOptions.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (ImplOptions.CollateMetrics)
                {
                    // Collated output is asserted to be a single data view, saved to the output file as is.
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, ImplOptions.OutputDataFile, perInstance[0]);
                }
                else
                {
                    // Otherwise each data view is saved under its own name derived from the output file and its index.
                    int i = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(ImplOptions.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }
        /// <summary>
        /// Applies the transforms described by <paramref name="tagData"/> on top of <paramref name="srcLoader"/>,
        /// calling <paramref name="createTransform"/> once per element, and records every newly created
        /// transform together with its tag and args string.
        /// </summary>
        /// <param name="host">The host used for assertions, for the "Transforms" channel, and as the
        /// environment passed to <paramref name="createTransform"/>.</param>
        /// <param name="srcLoader">The loader to build on. If it is a <see cref="LegacyCompositeDataLoader"/>,
        /// its existing transforms and underlying loader are reused.</param>
        /// <param name="tagData">The (tag, args string) pairs; asserted to be non-empty. A null or empty tag
        /// is replaced by one generated from the current number of recorded transforms.</param>
        /// <param name="createTransform">Invoked with (host, index into <paramref name="tagData"/>, current view);
        /// returns the new view.</param>
        /// <returns><paramref name="srcLoader"/> itself when the final view is the same object as the starting
        /// view; otherwise a new <see cref="LegacyCompositeDataLoader"/> over the recorded transforms.</returns>
        private static ILegacyDataLoader ApplyTransformsCore(IHost host, ILegacyDataLoader srcLoader,
                                                             KeyValuePair <string, string>[] tagData, Func <IHostEnvironment, int, IDataView, IDataView> createTransform)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertNonEmpty(tagData);
            host.AssertValue(createTransform, "createTransform");

            // If the loader is a composite, we need to start with its underlying pipeline end.
            var               exes      = new List <TransformEx>();
            var               composite = srcLoader as LegacyCompositeDataLoader;
            IDataView         srcView;
            ILegacyDataLoader pipeStart;

            if (composite != null)
            {
                // Continue from the composite's last view, keeping its transforms and its underlying loader.
                srcView = composite.View;
                exes.AddRange(composite._transforms);
                pipeStart = composite._loader;
            }
            else
            {
                // A plain loader is both the starting view and the start of the pipe.
                srcView = pipeStart = srcLoader;
            }

            IDataView view = srcView;

            using (var ch = host.Start("Transforms"))
            {
                int count        = Utils.Size(tagData);
                // Scratch list, reused for every element of tagData.
                var newlyCreated = new List <TransformEx>();
                for (int i = 0; i < count; i++)
                {
                    // REVIEW: this might cause silent automatic tag conflicts if the pipeline is short-circuited.
                    // Maybe it's better to allow empty tags?
                    var tag = tagData[i].Key;
                    if (string.IsNullOrEmpty(tag))
                    {
                        tag = GenerateTag(exes.Count);
                    }

                    var newDataView = createTransform(host, i, view);
                    // Append the newly created transforms to the exes list.
                    // If the newTransform is a 'no-op' transform, i.e. equal to the original view,
                    // the exes array will not be modified: there's no reason to record details of a no-op transform,
                    // especially since this would overwrite the useful details of the upstream transform.
                    newlyCreated.Clear();
                    IDataView curDataView = newDataView;
                    // Walk backwards from the new view through each transform's Source until reaching either
                    // something that is not a transform or a transform that is already recorded in exes.
                    while (true)
                    {
                        var cur = curDataView as IDataTransform;
                        if (cur == null)
                        {
                            // We reached all the way back to the pipe start. The exes accumulated so far are irrelevant.
                            ch.Check(curDataView == pipeStart,
                                     "The transform has corrupted the chain (chain no longer starts with the same loader).");
                            exes.Clear();
                            break;
                        }

                        int index = exes.FindLastIndex(x => x.Transform == cur);
                        if (index >= 0)
                        {
                            // We found a transform in exes to attach to.
                            if (index < exes.Count - 1)
                            {
                                // The transform short-circuited some of the existing ones, remove them.
                                exes.RemoveRange(index + 1, exes.Count - index - 1);
                            }
                            break;
                        }

                        // Not recorded yet: every transform found in this walk gets the same tag and args string.
                        newlyCreated.Add(new TransformEx(tag, tagData[i].Value, cur));
                        curDataView = cur.Source;
                    }

                    // The walk collected transforms from last to first; restore pipeline order before appending.
                    newlyCreated.Reverse();
                    exes.AddRange(newlyCreated);

                    view = newDataView;
                }
            }

            // If the final view is the very view we started from, hand back the original loader.
            return(view == srcView ? srcLoader : new LegacyCompositeDataLoader(host, exes.ToArray()));
        }
예제 #21
0
        /// <summary>
        /// Converts the configured pipeline to an ONNX model and writes it to the output model
        /// path, optionally also as indented JSON, and optionally saves the data loader.
        /// </summary>
        /// <remarks>
        /// The starting view comes from one of three places: the in-memory <c>_model</c>, the
        /// in-memory <c>_predictiveModel</c>, or a loader created from the command line options
        /// or the input model file.
        /// </remarks>
        /// <param name="ch">The channel used for trace and warning messages.</param>
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader = null;
            IPredictor rawPred = null;
            RoleMappedSchema trainSchema = null;
            IDataView view;

            if (_model != null)
            {
                // A transform model was handed to us directly: apply it to an empty view
                // that carries the model's input schema.
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }
            else if (_predictiveModel != null)
            {
                // A predictive model supplies the transform chain, the predictor and the training schema.
                view = _predictiveModel.TransformModel.Apply(
                    Host,
                    new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
                rawPred = _predictiveModel.Predictor;
                trainSchema = _predictiveModel.GetTrainingSchema(Host);
            }
            else
            {
                // No in-memory model: everything comes from the options and/or the input model file.
                if (!string.IsNullOrEmpty(ImplOptions.InputModelFile))
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }
                else
                {
                    loader = CreateLoader();
                    // Without an input model file there is nothing to load a predictor from.
                    Host.CheckUserArg(
                        ImplOptions.LoadPredictor != true,
                        nameof(ImplOptions.LoadPredictor),
                        "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
                }

                view = loader;
            }

            // The ONNX context holds the global information; the producer version is the
            // file version of the executing assembly.
            var fileVersionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(
                System.Reflection.Assembly.GetExecutingAssembly().Location);
            var ctx = new OnnxContextImpl(
                Host, _name, ProducerName, fileVersionInfo.FileVersion,
                ModelVersion, _domain, ImplOptions.OnnxVersion);

            // Walk the view to obtain the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;
            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            if (rawPred == null)
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }
            else
            {
                // We have a predictor: try to append its scorer to the chain. When the model
                // stored no roles, fall back to the default column names (all optional).
                RoleMappedData data = trainSchema != null
                    ? new RoleMappedData(end, trainSchema.GetColumnRoleNames())
                    : new RoleMappedData(end, DefaultColumnNames.Label,
                                         DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                bool scorerIsExportable = scoreOnnx != null && scoreOnnx.CanSaveOnnx(ctx) == true;

                if (!scorerIsExportable)
                {
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
                else
                {
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);

                    bool isClassifier =
                        rawPred.PredictionKind == PredictionKind.BinaryClassification ||
                        rawPred.PredictionKind == PredictionKind.MulticlassClassification;
                    if (isClassifier)
                    {
                        // If PredictedLabel is a key column with key-value annotations, map it back
                        // to values so that NimbusML can read them from the ONNX model; see
                        // https://github.com/dotnet/machinelearning/pull/4841
                        var predictedLabelColumn = scorePipe.Schema.GetColumnOrNull(DefaultColumnNames.PredictedLabel);
                        if (predictedLabelColumn.HasValue && HasKeyValues(predictedLabelColumn.Value))
                        {
                            end = new KeyToValueMappingTransformer(Host, DefaultColumnNames.PredictedLabel).Transform(scorePipe);
                            transforms.AddLast(end as ITransformCanSaveOnnx);
                        }
                    }
                }
            }

            // Key-typed "pass-through" columns (those the model left untouched) are likewise
            // converted back to values, for the same NimbusML reason as above
            // (https://github.com/dotnet/machinelearning/pull/4841).
            foreach (var name in GetPassThroughKeyDataViewTypeColumnsNames(source, end))
            {
                end = new KeyToValueMappingTransformer(Host, name).Transform(end);
                transforms.AddLast(end as ITransformCanSaveOnnx);
            }

            var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

            // Binary model output.
            using (var file = Host.CreateOutputFile(_outputModelPath))
            using (var stream = file.CreateWriteStream())
            {
                model.WriteTo(stream);
            }

            // Optional JSON output, re-serialized with indentation.
            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    writer.Write(JsonConvert.SerializeObject(
                        JsonConvert.DeserializeObject(model.ToString()),
                        Formatting.Indented));
                }
            }

            if (string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                return;
            }

            Contracts.Assert(loader != null);

            ch.Trace("Saving the data pipe");
            // Should probably include "end"?
            SaveLoader(loader, ImplOptions.OutputModelFile);
        }
 /// <summary>
 /// Passes <paramref name="loader"/> and the transforms configured in the options to
 /// <c>LegacyCompositeDataLoader.Create</c> and returns the resulting loader.
 /// </summary>
 /// <param name="loader">The loader handed to <c>LegacyCompositeDataLoader.Create</c>.</param>
 /// <returns>The loader returned by <c>LegacyCompositeDataLoader.Create</c>.</returns>
 private ILegacyDataLoader CreateTransformChain(ILegacyDataLoader loader)
     => LegacyCompositeDataLoader.Create(Host, loader, ImplOptions.Transforms);
            /// <summary>
            /// Opens the input model file and loads from it the artifacts the caller asks for,
            /// combined with what the command line options specify.
            /// </summary>
            /// <param name="ch">The channel that receives trace output.</param>
            /// <param name="wantPredictor">Controls predictor loading. <c>false</c>: do not attempt
            /// to load a predictor at all. <c>null</c>: load the predictor if one is present.
            /// <c>true</c>: load the predictor and fail if it cannot be loaded.</param>
            /// <param name="predictor">Receives the predictor from the model, or <c>null</c> when
            /// <paramref name="wantPredictor"/> is <c>false</c>, or when it is <c>null</c> and the
            /// model holds no predictor.</param>
            /// <param name="wantTrainSchema">Whether the training schema should be produced. This has
            /// no "fail if missing" mode: even when <c>true</c>, <paramref name="trainSchema"/> is
            /// left <c>null</c> if the model stored no role mappings.</param>
            /// <param name="trainSchema">Receives the training schema when
            /// <paramref name="wantTrainSchema"/> is <c>true</c> and role mappings were found in the
            /// model; otherwise <c>null</c>.</param>
            /// <param name="pipe">Receives the data pipe built from the model and the command line
            /// options.</param>
            protected void LoadModelObjects(
                IChannel ch,
                bool? wantPredictor, out IPredictor predictor,
                bool wantTrainSchema, out RoleMappedSchema trainSchema,
                out ILegacyDataLoader pipe)
            {
                // Everything below reads from the repository stored in the input model file.
                using (var modelFile = Host.OpenInputFile(ImplOptions.InputModelFile))
                using (var modelStream = modelFile.OpenReadStream())
                using (var repository = RepositoryReader.Open(modelStream, Host))
                {
                    // Step 1: the predictor, unless the caller explicitly declined it.
                    predictor = null;
                    if (wantPredictor != false)
                    {
                        ch.Trace("Loading predictor");
                        predictor = ModelFileUtils.LoadPredictorOrNull(Host, repository);
                        if (wantPredictor == true)
                        {
                            Host.Check(predictor != null, "Could not load predictor from model file");
                        }
                    }

                    // Step 2: the loader. trainPipe is set only when the pipe is loaded from the
                    // model together with its transforms.
                    ILegacyDataLoader trainPipe = null;
                    var loaderFactory = ImplOptions.Loader;
                    if (loaderFactory == null)
                    {
                        // Loader comes from the model; transforms are loaded unless explicitly disabled.
                        bool withTransforms = ImplOptions.LoadTransforms ?? true;
                        pipe = LoadLoader(repository, ImplOptions.DataFile, withTransforms);
                        if (withTransforms)
                        {
                            trainPipe = pipe;
                        }
                    }
                    else
                    {
                        // The command line overrides the loader; the model's transforms are
                        // attached only on explicit request.
                        pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(ImplOptions.DataFile));
                        if (ImplOptions.LoadTransforms == true)
                        {
                            Host.CheckUserArg(!string.IsNullOrWhiteSpace(ImplOptions.InputModelFile), nameof(ImplOptions.InputModelFile));
                            pipe = LoadTransformChain(pipe);
                        }
                    }

                    // Step 3: append any transforms given on the command line.
                    if (Utils.Size(ImplOptions.Transforms) > 0)
                    {
                        pipe = LegacyCompositeDataLoader.Create(Host, pipe, ImplOptions.Transforms);
                    }

                    // Step 4: the role-mapped training schema, if requested and available.
                    trainSchema = null;
                    if (!wantTrainSchema)
                    {
                        return;
                    }

                    var roleMappings = ModelFileUtils.LoadRoleMappingsOrNull(Host, repository);
                    if (roleMappings == null)
                    {
                        // Failing here would be an alternative, but the scorer is meant to keep
                        // working (possibly with reduced functionality) without a training schema,
                        // since not every model version preserves the role mappings and backwards
                        // compatibility matters.
                        return;
                    }

                    // Reuse the pipe loaded above when it was loaded from the model with its
                    // transforms. Otherwise load the full pipeline from the model with no data,
                    // relying on loaders being loadable without data, just to obtain its schema.
                    if (trainPipe == null)
                    {
                        trainPipe = ModelFileUtils.LoadLoader(Host, repository, new MultiFileSource(null), loadTransforms: true);
                    }

                    trainSchema = new RoleMappedSchema(trainPipe.Schema, roleMappings);
                }
            }
 /// <summary>
 /// Creates an output file at <paramref name="path"/> and saves <paramref name="loader"/> into it
 /// via <c>LoaderUtils.SaveLoader</c>.
 /// </summary>
 /// <param name="loader">The loader to save.</param>
 /// <param name="path">The path of the output file to create.</param>
 protected void SaveLoader(ILegacyDataLoader loader, string path)
 {
     using (var outputFile = Host.CreateOutputFile(path))
     {
         LoaderUtils.SaveLoader(loader, outputFile);
     }
 }
예제 #25
0
        /// <summary>
        /// Converts the configured pipeline to an ONNX model and writes it to the output model
        /// path, optionally also as indented JSON, and optionally saves the data loader.
        /// </summary>
        /// <remarks>
        /// The starting view is either the in-memory <c>_model</c> applied to an empty view, or a
        /// loader created from the command line options or the input model file.
        /// </remarks>
        /// <param name="ch">The channel used for trace and warning messages.</param>
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader  = null;
            IPredictor        rawPred = null;
            IDataView         view;
            RoleMappedSchema  trainSchema = null;

            if (_model == null)
            {
                if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
                {
                    loader      = CreateLoader();
                    rawPred     = null;
                    trainSchema = null;
                    // Without an input model file there is nothing to load a predictor from.
                    Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                                      "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else
            {
                // A transform model was handed to us directly: apply it to an empty view
                // that carries the model's input schema.
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }

            // Create the ONNX context for storing global information. The producer version
            // is the file version of the executing assembly.
            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                                                  ModelVersion, _domain, ImplOptions.OnnxVersion);

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList <ITransformCanSaveOnnx> transforms;

            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                                              DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    // The scorer can be exported: it becomes the new end of the chain.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    // Only an error when the predictor was explicitly requested; otherwise warn and go on.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

            // Binary model output.
            using (var file = Host.CreateOutputFile(_outputModelPath))
                using (var stream = file.CreateWriteStream())
                    model.WriteTo(stream);

            // Optional JSON output, re-serialized with indentation.
            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                    using (var stream = file.CreateWriteStream())
                        using (var writer = new StreamWriter(stream))
                        {
                            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                        }
            }

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                // NOTE(review): loader is only assigned in the `_model == null` branch above, so it
                // is null here when an in-memory model was used. Confirm that OutputModelFile is
                // never set in that configuration.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }
        }