/// <summary>
        /// Loads an optional object from the repository directory <paramref name="dir"/>.
        /// The entry named <c>ModelStreamName</c> is tried first; if it is absent, the entry named <c>NameBinary</c> is tried.
        /// Returns false if and only if neither entry exists, in which case <paramref name="result"/> is set to null.
        /// Throws if loading fails for any other reason.
        /// </summary>
        public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
            where TRes : class
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));

            // First choice: the model stream entry. The repository, entry, and directory name are all
            // provided to the loadable class ctor.
            using (var modelEntry = rep.OpenEntryOrNull(dir, ModelStreamName))
            {
                if (modelEntry != null)
                {
                    env.Assert(modelEntry.Stream.Position == 0);
                    LoadModel<TRes, TSig>(env, out result, rep, modelEntry, dir, extra);
                    return true;
                }
            }

            // Fallback: the binary entry, which is loaded from its stream alone.
            using (var binaryEntry = rep.OpenEntryOrNull(dir, NameBinary))
            {
                if (binaryEntry != null)
                {
                    env.Assert(binaryEntry.Stream.Position == 0);
                    LoadModel<TRes, TSig>(env, out result, binaryEntry.Stream, extra);
                    return true;
                }
            }

            // Neither entry exists.
            result = null;
            return false;
        }
Example #2
0
        /// <summary>
        /// Loads the data view stored in <paramref name="rep"/>. When <paramref name="loadTransforms"/> is true,
        /// the loader is loaded together with its transforms; otherwise only the loader itself is loaded.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="rep">The repository reader.</param>
        /// <param name="files">The data source passed to the loader.</param>
        /// <param name="loadTransforms">Whether to load the transforms as well as the loader.</param>
        /// <returns>The loaded data loader.</returns>
        public static ILegacyDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            Repository.Entry ent = null;
            string dir = null;

            if (!loadTransforms)
            {
                // Loader-only mode: try the "Loader" sub-directory, which exists for a composite loader.
                string innerLoaderDir = Path.Combine(DirDataLoaderModel, "Loader");
                ent = rep.OpenEntryOrNull(innerLoaderDir, ModelLoadContext.ModelStreamName);
                if (ent != null)
                {
                    dir = innerLoaderDir;
                }
            }

            if (ent == null)
            {
                // Either the transforms were requested, or the saved loader is not a composite one.
                dir = DirDataLoaderModel;
                ent = rep.OpenEntry(dir, ModelLoadContext.ModelStreamName);
            }

            env.CheckDecode(ent != null, "Loader is not found.");
            env.AssertNonEmpty(dir);
            using (ent)
            {
                env.Assert(ent.Stream.Position == 0);
                ModelLoadContext.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, rep, ent, dir, files);
                return loader;
            }
        }
        /// <summary>
        /// Loads the loader and transforms stored in the given repository and returns the resulting data view.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="rep">The repository reader.</param>
        /// <param name="files">The data source to initialize the loader with.</param>
        /// <param name="extractInnerPipe">Whether to extract the transforms and loader from the wrapped CompositeDataLoader.</param>
        /// <returns>The created data view.</returns>
        public static IDataView LoadPipeline(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool extractInnerPipe = false)
        {
            // REVIEW: Should not duplicate loading loader/transforms code. This method should call LoadLoader.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            var schemaEntry = rep.OpenEntryOrNull(SchemaEntryName);
            if (schemaEntry != null)
            {
                // A schema entry is present: read it with a BinaryLoader and apply the transformer chain stored
                // under DirTransformerChain. Note that 'files' is not used on this path.
                // NOTE(review): schemaEntry is never disposed here; presumably the BinaryLoader keeps reading
                // from its stream after this method returns — confirm stream ownership before changing this.
                var schemaView = new BinaryLoader(env, new BinaryLoader.Arguments(), schemaEntry.Stream);
                ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(env, out var chain, rep, DirTransformerChain);
                return chain.Transform(schemaView);
            }

            using (var loaderEntry = rep.OpenEntry(DirDataLoaderModel, ModelLoadContext.ModelStreamName))
            {
                env.Assert(loaderEntry.Stream.Position == 0);
                ModelLoadContext.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, rep, loaderEntry, DirDataLoaderModel, files);

                // When requested, unwrap a composite loader and hand back its inner view instead.
                if (extractInnerPipe && loader is LegacyCompositeDataLoader composite)
                {
                    return composite.View;
                }
                return loader;
            }
        }
Example #4
0
 /// <summary>
 /// Loads all transforms from the model stream, applies them sequentially to the provided data, and returns
 /// the resulting data. If there are no transforms in the stream, or if there's no DataLoader stream at all
 /// (this can happen if the model is produced by old TL), returns the source data.
 /// If the DataLoader stream is invalid, throws.
 /// </summary>
 /// <param name="env">The host environment to use.</param>
 /// <param name="data">The starting data view.</param>
 /// <param name="rep">The repository reader.</param>
 /// <returns>The resulting data view.</returns>
 public static IDataView LoadTransforms(IHostEnvironment env, IDataView data, RepositoryReader rep)
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(data, nameof(data));
     env.CheckValue(rep, nameof(rep));
     using (var ent = rep.OpenEntryOrNull(DirDataLoaderModel, ModelLoadContext.ModelStreamName))
     {
         // No DataLoader stream in the repository: there is nothing to apply.
         if (ent == null)
         {
             return data;
         }

         // ModelLoadContext is IDisposable, so release it once the transforms have been read.
         // The entry itself is disposed right after this scope, so nothing can keep reading
         // through the context afterwards.
         using (var ctx = new ModelLoadContext(rep, ent, DirDataLoaderModel))
         {
             // The predicate accepts every transform (it presumably acts as a per-transform filter).
             return LegacyCompositeDataLoader.LoadSelectedTransforms(ctx, data, env, x => true);
         }
     }
 }
Example #5
0
        /// <summary>
        /// Returns role/column-name pairs loaded from a repository, or null if the repository contains no
        /// role-mapping entry (<c>RoleMappingFile</c> under <c>DirTrainingInfo</c>).
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="rep">The repository reader.</param>
        /// <returns>The role/column-name pairs, or null when no role mappings were saved.</returns>
        public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNull(IHostEnvironment env, RepositoryReader rep)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            var h = env.Register("RoleMappingUtils");

            // Only probe for the entry here; the actual read goes through RepositoryStreamWrapper below.
            var entry = rep.OpenEntryOrNull(DirTrainingInfo, RoleMappingFile);
            if (entry == null)
            {
                return null;
            }
            entry.Dispose();

            var list = new List<KeyValuePair<string, string>>();

            using (var ch = h.Start("Loading role mappings"))
            {
                // REVIEW: Should really validate the schema here, and consider
                // ignoring this stream if it isn't as expected.
                var repoStreamWrapper = new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile);
                var loader = new TextLoader(env, dataSample: repoStreamWrapper).Load(repoStreamWrapper);

                using (var cursor = loader.GetRowCursorForAllColumns())
                {
                    // Each row holds the role name in column 0 and the column name in column 1.
                    var roleGetter = cursor.GetGetter<ReadOnlyMemory<char>>(cursor.Schema[0]);
                    var colGetter = cursor.GetGetter<ReadOnlyMemory<char>>(cursor.Schema[1]);
                    var role = default(ReadOnlyMemory<char>);
                    var col = default(ReadOnlyMemory<char>);
                    while (cursor.MoveNext())
                    {
                        roleGetter(ref role);
                        colGetter(ref col);
                        string roleStr = role.ToString();
                        string colStr = col.ToString();

                        h.CheckDecode(!string.IsNullOrWhiteSpace(roleStr), "Role name must not be empty");
                        h.CheckDecode(!string.IsNullOrWhiteSpace(colStr), "Column name must not be empty");
                        list.Add(new KeyValuePair<string, string>(roleStr, colStr));
                    }
                }
            }

            // Converts the raw string pairs into ColumnRole pairs (presumably validating them as well).
            return TrainUtils.CheckAndGenerateCustomColumns(env, list.ToArray());
        }
Example #6
0
        /// <summary>
        /// Optionally loads feature names from the training-info directory of the repository.
        /// Returns false if and only if no feature-name stream was found, in which case
        /// <paramref name="featureNames"/> is set to null.
        /// </summary>
        /// <remarks>
        /// REVIEW: consider adding an overload that returns <see cref="ReadOnlyMemory{T}"/> of <see cref="char"/>.
        /// </remarks>
        public static bool TryLoadFeatureNames(out FeatureNameCollection featureNames, RepositoryReader rep)
        {
            Contracts.CheckValue(rep, nameof(rep));

            using (var entry = rep.OpenEntryOrNull(ModelFileUtils.DirTrainingInfo, "FeatureNames.bin"))
            {
                // No stored feature names: report absence rather than failing.
                if (entry == null)
                {
                    featureNames = null;
                    return false;
                }

                using (var loadContext = new ModelLoadContext(rep, entry, ModelFileUtils.DirTrainingInfo))
                {
                    featureNames = FeatureNameCollection.Create(loadContext);
                }
                return true;
            }
        }