/// <summary>
/// Return role/column-name pairs loaded from the role-mapping stream stored in the
/// training-info directory of a repository.
/// </summary>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="rep">The repository to read from. Must not be null.</param>
/// <returns>
/// The pairs produced by <c>TrainUtils.CheckAndGenerateCustomColumns</c> from the stored
/// (role, column) string pairs, or <c>null</c> if the repository contains no role-mapping stream.
/// </returns>
/// <remarks>
/// Throws a decode error (via <c>CheckDecode</c>) if any stored role name or column name is
/// empty or whitespace.
/// </remarks>
public static IEnumerable<KeyValuePair<ColumnRole, string>> LoadRoleMappingsOrNull(IHostEnvironment env, RepositoryReader rep)
{
    Contracts.CheckValue(env, nameof(env));
    // Validate up front, like the other loaders here, instead of failing with a
    // NullReferenceException on the first repository access.
    env.CheckValue(rep, nameof(rep));
    var h = env.Register("RoleMappingUtils");

    var list = new List<KeyValuePair<string, string>>();

    // Only probe for the stream's existence here. The text loader below opens the entry
    // again through RepositoryStreamWrapper, so the probe entry is released right away.
    using (var entry = rep.OpenEntryOrNull(DirTrainingInfo, RoleMappingFile))
    {
        if (entry == null)
        {
            return null;
        }
    }

    using (var ch = h.Start("Loading role mappings"))
    {
        // REVIEW: Should really validate the schema here, and consider
        // ignoring this stream if it isn't as expected.
        var loaderSub = new SubComponent<IDataLoader, SignatureDataLoader>("Text");
        var loader = loaderSub.CreateInstance(env, new RepositoryStreamWrapper(rep, DirTrainingInfo, RoleMappingFile));

        using (var cursor = loader.GetRowCursor(c => true))
        {
            // Column 0 holds the role name, column 1 the column name.
            var roleGetter = cursor.GetGetter<DvText>(0);
            var colGetter = cursor.GetGetter<DvText>(1);
            var role = default(DvText);
            var col = default(DvText);
            while (cursor.MoveNext())
            {
                roleGetter(ref role);
                colGetter(ref col);
                string roleStr = role.ToString();
                string colStr = col.ToString();

                h.CheckDecode(!string.IsNullOrWhiteSpace(roleStr), "Role name must not be empty");
                h.CheckDecode(!string.IsNullOrWhiteSpace(colStr), "Column name must not be empty");
                list.Add(new KeyValuePair<string, string>(roleStr, colStr));
            }
        }

        ch.Done();
    }

    return TrainUtils.CheckAndGenerateCustomColumns(env, list.ToArray());
}
/// <summary>
/// Returns the <see cref="RoleMappedSchema"/> from a repository, or <c>null</c> if there were no
/// role mappings present.
/// </summary>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="rep">The repository to read the role mappings and loader pipeline from. Must not be null.</param>
/// <returns>
/// A schema built from the stored loader pipeline's output schema and the stored role mappings,
/// or <c>null</c> when the repository has no role mappings.
/// </returns>
public static RoleMappedSchema LoadRoleMappedSchemaOrNull(IHostEnvironment env, RepositoryReader rep)
{
    Contracts.CheckValue(env, nameof(env));
    // Fail fast on a null repository instead of deep inside the callees.
    env.CheckValue(rep, nameof(rep));
    var h = env.Register("RoleMappingUtils");

    var roleMappings = ModelFileUtils.LoadRoleMappingsOrNull(env, rep);
    if (roleMappings == null)
    {
        return null;
    }

    // Only the pipeline's schema is used below. The data source is constructed from null
    // (presumably meaning "no files"), so no actual data is bound here.
    var pipe = ModelFileUtils.LoadLoader(h, rep, new MultiFileSource(null), loadTransforms: true);
    return new RoleMappedSchema(pipe.Schema, roleMappings);
}
/// <summary>
/// Loads the optional feature-name collection from the training-info directory of a repository.
/// REVIEW: consider adding an overload that returns <see cref="VBuffer{DvText}"/>
/// </summary>
/// <param name="featureNames">
/// Receives the loaded collection, or <c>null</c> when the repository has no feature-name stream.
/// </param>
/// <param name="rep">The repository to read from. Must not be null.</param>
/// <returns>
/// <c>true</c> if the feature-name stream was found and loaded; <c>false</c> (with
/// <paramref name="featureNames"/> set to <c>null</c>) if no such stream exists.
/// </returns>
public static bool TryLoadFeatureNames(out FeatureNameCollection featureNames, RepositoryReader rep)
{
    Contracts.CheckValue(rep, nameof(rep));

    // A using over a null resource is a no-op, so the missing-stream case is handled inside.
    using (var namesEntry = rep.OpenEntryOrNull(ModelFileUtils.DirTrainingInfo, "FeatureNames.bin"))
    {
        if (namesEntry == null)
        {
            featureNames = null;
            return false;
        }

        using (var loadCtx = new ModelLoadContext(rep, namesEntry, ModelFileUtils.DirTrainingInfo))
        {
            featureNames = FeatureNameCollection.Create(loadCtx);
        }

        return true;
    }
}
/// <summary>
/// Loads a required object from the given repository directory.
/// </summary>
/// <remarks>
/// Delegates to <c>LoadModelOrNull</c>. If that finds no model stream in <paramref name="dir"/>,
/// a decode error ("Corrupt model file") is thrown. Otherwise <paramref name="result"/> is
/// asserted to be non-null.
/// </remarks>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="result">Receives the loaded object.</param>
/// <param name="rep">The repository to read from. Must not be null.</param>
/// <param name="dir">The repository directory holding the object's stream.</param>
/// <param name="extra">Extra arguments forwarded to the loadable class.</param>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));

    bool found = LoadModelOrNull<TRes, TSig>(env, out result, rep, dir, extra);
    if (found)
    {
        env.AssertValue(result);
        return;
    }

    throw env.ExceptDecode("Corrupt model file");
}
/// <summary>
/// Loads an optional object from the given repository directory.
/// </summary>
/// <remarks>
/// The <c>ModelStreamName</c> entry is tried first and loaded with the repository-aware loader.
/// If it is absent, the <c>NameBinary</c> entry is tried and loaded from its raw stream.
/// </remarks>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="result">Receives the loaded object, or <c>null</c> if neither entry exists.</param>
/// <param name="rep">The repository to read from. Must not be null.</param>
/// <param name="dir">The repository directory to look in.</param>
/// <param name="extra">Extra arguments forwarded to the loadable class.</param>
/// <returns>
/// <c>true</c> if one of the entries was found and loaded; <c>false</c> (with
/// <paramref name="result"/> set to <c>null</c>) if neither entry exists. Throws if loading
/// fails for any other reason.
/// </returns>
public static bool LoadModelOrNull<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));

    // Preferred format: hand the repository, entry and directory to the loadable class ctor.
    using (var modelEntry = rep.OpenEntryOrNull(dir, ModelStreamName))
    {
        if (modelEntry != null)
        {
            env.Assert(modelEntry.Stream.Position == 0);
            LoadModel<TRes, TSig>(env, out result, rep, modelEntry, dir, extra);
            return true;
        }
    }

    // Fallback format: load straight from the binary stream.
    using (var binaryEntry = rep.OpenEntryOrNull(dir, NameBinary))
    {
        if (binaryEntry != null)
        {
            env.Assert(binaryEntry.Stream.Position == 0);
            LoadModel<TRes, TSig>(env, out result, binaryEntry.Stream, extra);
            return true;
        }
    }

    result = null;
    return false;
}
/// <summary>
/// Load from the given repository entry using the default loader(s) specified in the header.
/// </summary>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="result">Receives the loaded object.</param>
/// <param name="rep">The repository that owns <paramref name="ent"/>. Must not be null.</param>
/// <param name="ent">The open repository entry to load from. Must not be null.</param>
/// <param name="dir">The repository directory the entry belongs to; also used in the error message.</param>
/// <param name="extra">Extra arguments forwarded to the loadable class.</param>
/// <remarks>
/// Throws a decode error ("Couldn't load model: '...'") if the header's loader(s) cannot be bound
/// to a compatible loadable class.
/// </remarks>
public static void LoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));
    // The entry's stream is dereferenced unconditionally while loading, so reject null here
    // rather than failing with a NullReferenceException further down.
    env.CheckValue(ent, nameof(ent));

    if (!TryLoadModel<TRes, TSig>(env, out result, rep, ent, dir, extra))
    {
        throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
    }
}
/// <summary>
/// Attempts to load from the given repository entry using the default loader(s) named in its header.
/// </summary>
/// <param name="env">The host environment. Must not be null.</param>
/// <param name="result">Receives the loaded object when loading succeeds.</param>
/// <param name="rep">The repository that owns <paramref name="ent"/>. Must not be null.</param>
/// <param name="ent">The open repository entry to load from.</param>
/// <param name="dir">The repository directory the entry belongs to.</param>
/// <param name="extra">Extra arguments forwarded to the loadable class.</param>
/// <returns>
/// <c>false</c> iff the default loader(s) could not be bound to a compatible loadable class.
/// </returns>
private static bool TryLoadModel<TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
    where TRes : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(rep, nameof(rep));

    // Remember where the entry's stream started so a failed attempt can be checked for rewinding.
    long startPos = ent.Stream.Position;
    bool loaded;
    using (var loadCtx = new ModelLoadContext(rep, ent, dir))
    {
        env.Assert(startPos == loadCtx.FpMin);
        loaded = loadCtx.TryLoadModelCore<TRes, TSig>(env, out result, extra);
    }

    if (loaded)
    {
        return true;
    }

    // TryLoadModelCore should rewind on failure.
    Contracts.Assert(startPos == ent.Stream.Position);
    return false;
}