Esempio n. 1
0
        /// <summary>
        /// Writes this transformer into the supplied save context: the prediction model, the training
        /// schema, the feature column names, and the scorer threshold settings.
        /// </summary>
        /// <param name="ctx">The <see cref="ModelSaveContext"/> that receives the serialized transformer.</param>
        public void Save(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // ids of strings: feature columns.
            // float: scorer threshold
            // id of string: scorer threshold column

            ctx.SaveModel(Model, DirModel);

            // The training schema goes into its own binary stream, encoded as a data view without rows.
            ctx.SaveBinaryStream(DirTransSchema, schemaWriter =>
            {
                using (var channel = Host.Start("Saving train schema"))
                {
                    var schemaSaver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
                    var emptyView = new EmptyDataView(Host, TrainSchema);
                    DataSaverUtils.SaveDataView(channel, schemaSaver, emptyView, schemaWriter.BaseStream);
                }
            });

            // One feature column name is written per field of the model.
            int field = 0;
            while (field < Model.FieldCount)
            {
                ctx.SaveString(FeatureColumns[field]);
                field++;
            }

            ctx.Writer.Write(_threshold);
            ctx.SaveString(_thresholdColumn);
        }
Esempio n. 2
0
 /// <summary>
 /// Saves the bindable mapper as a sub-model and then hands the context to <c>SaveCore</c>
 /// for the remaining state.
 /// </summary>
 /// <param name="ctx">The save context; it must be positioned at a model (see <c>CheckAtModel</c>).</param>
 public sealed override void Save(ModelSaveContext ctx)
 {
     Contracts.AssertValue(ctx);
     ctx.CheckAtModel();

     // The wrapped mapper is stored under its own entry before the derived state is written.
     ctx.SaveModel(Bindable, "SchemaBindableMapper");

     SaveCore(ctx);
 }
        /// <summary>
        /// Serializes this loader: the tailing directory count, the schema (as a row-less binary data view),
        /// the source directory indices, the sub-loader bytes, and the path parser sub-model.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public void Save(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // ** Binary format **
            // int: tailing directory count
            // Schema of the loader
            // int[]: srcColumns
            // byte[]: subloader
            // model: file path spec

            ctx.Writer.Write(_tailingDirCount);

            // Capture the schema by running the binary saver over a view that has no rows.
            var emptyView = new EmptyDataView(_host, Schema);
            var schemaSaver = new BinarySaver(_host, new BinarySaver.Arguments { Silent = true });

            using (var buffer = new MemoryStream())
            {
                // Every column of the schema is included.
                int[] columnIndices = Enumerable.Range(0, Schema.Count).ToArray();
                schemaSaver.SaveData(buffer, emptyView, columnIndices);
                ctx.SaveBinaryStream(SchemaCtxName, schemaWriter => schemaWriter.WriteByteArray(buffer.ToArray()));
            }

            ctx.Writer.WriteIntArray(_srcDirIndex);
            ctx.Writer.WriteByteArray(_subLoaderBytes);

            ctx.SaveModel(_pathParser, FilePathSpecCtxName);
        }
Esempio n. 4
0
            /// <summary>
            /// Saves the chain of transforms ending at <c>_xf</c>, preceded by a "Loader" sub-model that
            /// records the schema of the data the chain starts from.
            /// </summary>
            /// <param name="ctx">The save context to write into.</param>
            public void Save(ModelSaveContext ctx)
            {
                _host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // Walk from the last transform back through each Source until a non-transform is reached.
                var chain = new List<IDataTransform>();
                var root = _xf;
                while (root is IDataTransform link)
                {
                    chain.Add(link);
                    root = link.Source;
                    Contracts.AssertValue(root);
                }

                // The walk collected the transforms last-to-first; store them first-to-last.
                chain.Reverse();

                ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, root.Schema));

                ctx.Writer.Write(chain.Count);

                int position = 0;
                foreach (var transform in chain)
                {
                    ctx.SaveModel(transform, string.Format(TransformDirTemplate, position));
                    position++;
                }
            }
Esempio n. 5
0
        /// <summary>
        /// Writes the float size, the transform count, the loader (into a "Loader" sub-context), and then
        /// one sub-model plus tag and args strings for each transform.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        /// <param name="loaderSaveAction">Callback that saves the loader into the "Loader" sub-context.</param>
        /// <param name="transforms">The transforms to save, in order. May be null, which is written as zero transforms.</param>
        private static void SaveCore(ModelSaveContext ctx, Action<ModelSaveContext> loaderSaveAction, TransformEx[] transforms)
        {
            Contracts.AssertValue(ctx);
            Contracts.AssertValue(loaderSaveAction);
            Contracts.AssertValueOrNull(transforms);

            // The assertion above allows a null array, so it must not be dereferenced directly.
            int transformCount = transforms?.Length ?? 0;

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            ctx.Writer.Write(sizeof(Float));
            ctx.Writer.Write(transformCount);

            using (var loaderCtx = new ModelSaveContext(ctx.Repository, Path.Combine(ctx.Directory ?? "", "Loader"), ModelLoadContext.ModelStreamName))
            {
                loaderSaveAction(loaderCtx);
                loaderCtx.Done();
            }

            for (int i = 0; i < transformCount; i++)
            {
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.SaveModel(transforms[i].Transform, dirName);

                // A tag is mandatory for every transform; the args string is optional.
                Contracts.AssertNonEmpty(transforms[i].Tag);
                ctx.SaveNonEmptyString(transforms[i].Tag);
                ctx.SaveStringOrNull(transforms[i].ArgsString);
            }
        }
        /// <summary>
        /// Saves the base state followed by the seasonal window size, discount factor, error function,
        /// adaptivity flag, the state object, and the model under the "SSA" entry.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public override void Save(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            // Sanity checks on the values about to be written.
            Host.Assert(InitialWindowSize == 0);
            Host.Assert(SeasonalWindowSize >= 2);
            Host.Assert(DiscountFactor >= 0 && DiscountFactor <= 1);
            Host.Assert(Enum.IsDefined(typeof(ErrorFunctionUtils.ErrorFunction), ErrorFunction));
            Host.Assert(Model != null);

            // *** Binary format ***
            // <base>
            // int: _seasonalWindowSize
            // float: _discountFactor
            // byte: _errorFunction
            // bool: _isAdaptive
            // State: StateRef
            // AdaptiveSingularSpectrumSequenceModeler: _model

            base.Save(ctx);

            ctx.Writer.Write(SeasonalWindowSize);
            ctx.Writer.Write(DiscountFactor);
            // The error function is stored as a single byte.
            ctx.Writer.Write((byte)ErrorFunction);
            ctx.Writer.Write(IsAdaptive);

            StateRef.Save(ctx.Writer);
            ctx.SaveModel(Model, "SSA");
        }
Esempio n. 7
0
        /// <summary>
        /// Saves the chain of transforms ending at <c>_xf</c>, preceded by a "Loader" sub-model that
        /// records the schema of the data the chain starts from.
        /// </summary>
        /// <param name="ctx">The save context to write into. Must not be null.</param>
        public void Save(ModelSaveContext ctx)
        {
            // Validate the argument before the first dereference so a null context is reported as a
            // contract violation rather than a NullReferenceException.
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            var dataPipe = _xf;
            var transforms = new List<IDataTransform>();

            // Walk from the last transform back through each Source until a non-transform is reached.
            while (dataPipe is IDataTransform xf)
            {
                // REVIEW: a malicious user could construct a loop in the Source chain, that would
                // cause this method to iterate forever (and throw something when the list overflows). There's
                // no way to insulate from ALL malicious behavior.
                transforms.Add(xf);
                dataPipe = xf.Source;
                Contracts.AssertValue(dataPipe);
            }

            // The walk collected the transforms last-to-first; store them first-to-last.
            transforms.Reverse();

            ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

            ctx.Writer.Write(transforms.Count);
            for (int i = 0; i < transforms.Count; i++)
            {
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.SaveModel(transforms[i], dirName);
            }
        }
            /// <summary>
            /// Saves the count tables: the number of columns, then for each column its slot count,
            /// with every table stored as a separate sub-model.
            /// </summary>
            /// <param name="ctx">The save context to write into.</param>
            public override void Save(ModelSaveContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // *** Binary format ***
                // number of columns
                // for each column, number of slots
                // Sub-models:
                // count tables (each in a separate folder)

                Host.Assert(_countTables.Length > 0);
                ctx.Writer.Write(_countTables.Length);

                for (int col = 0; col < _countTables.Length; col++)
                {
                    var slotTables = _countTables[col];
                    int slotCount = slotTables.Length;
                    Host.Assert(slotCount > 0);
                    ctx.Writer.Write(slotCount);

                    // Each table's folder name encodes its column and slot index, zero-padded to three digits.
                    for (int slot = 0; slot < slotCount; slot++)
                    {
                        ctx.SaveModel(slotTables[slot], string.Format("Table_{0:000}_{1:000}", col, slot));
                    }
                }
            }
Esempio n. 9
0
        /// <summary>
        /// Saves the prediction model, the training schema, the feature column name, and the three
        /// (possibly null) output column names.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // string: feature column name.
            // string: the name of the columns where tree prediction values are stored.
            // string: the name of the columns where trees' leaves are stored.
            // string: the name of the columns where trees' paths are stored.

            ctx.SaveModel(Model, DirModel);

            // The training schema goes into its own binary stream, encoded as a data view without rows.
            ctx.SaveBinaryStream(DirTransSchema, schemaWriter =>
            {
                using (var channel = Host.Start("Saving train schema"))
                {
                    var schemaSaver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true });
                    var emptyView = new EmptyDataView(Host, TrainSchema);
                    DataSaverUtils.SaveDataView(channel, schemaSaver, emptyView, schemaWriter.BaseStream);
                }
            });

            // The feature column name is always written; the output column names may each be null.
            ctx.SaveString(_featureDetachedColumn.Name);
            ctx.SaveStringOrNull(_treesColumnName);
            ctx.SaveStringOrNull(_leavesColumnName);
            ctx.SaveStringOrNull(_pathsColumnName);
        }
Esempio n. 10
0
        /// <summary>
        /// Re-targets a data transform onto a different source data view.
        /// When the transform implements <see cref="ITransformTemplate"/>, the work is delegated to
        /// <see cref="ITransformTemplate.ApplyToData"/>. Otherwise the transform is saved to an in-memory
        /// repository and loaded back with <paramref name="newSource"/> as its input.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="transform">The transform to re-apply.</param>
        /// <param name="newSource">The data view the transform should be applied to.</param>
        /// <returns>The transform bound to <paramref name="newSource"/>.</returns>
        public static IDataTransform ApplyTransformToData(IHostEnvironment env, IDataTransform transform, IDataView newSource)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(transform, nameof(transform));
            env.CheckValue(newSource, nameof(newSource));

            // Preferred path: the transform knows how to rebind itself.
            if (transform is ITransformTemplate template)
            {
                return template.ApplyToData(env, newSource);
            }

            // Fallback: round-trip the transform through serialization.
            using (var buffer = new MemoryStream())
            {
                using (var writer = RepositoryWriter.CreateNew(buffer, env))
                {
                    ModelSaveContext.SaveModel(writer, transform, "model");
                    writer.Commit();
                }

                // Rewind so the reader starts at the beginning of what was just written.
                buffer.Position = 0;

                using (var reader = RepositoryReader.Open(buffer, env))
                {
                    ModelLoadContext.LoadModel<IDataTransform, SignatureLoadDataTransform>(
                        env, out IDataTransform reloaded, reader, "model", newSource);
                    return reloaded;
                }
            }
        }
 /// <summary>
 /// Saves the version info and the predictor, which is stored under the "predictor" entry.
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 public void Save(ModelSaveContext ctx)
 {
     Contracts.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();
     ctx.SetVersionInfo(GetVersionInfo());

     // A missing predictor is rejected before anything is saved for it.
     Contracts.CheckValue(_predictor, nameof(_predictor));
     ctx.SaveModel(_predictor, "predictor");
 }
Esempio n. 12
0
 /// <summary>
 /// Saves the version info, the arguments, and the scorer (under the "scorer" entry).
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 protected override void SaveModel(ModelSaveContext ctx)
 {
     _host.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();
     ctx.SetVersionInfo(GetVersionInfo());

     // Arguments are written first, then the scorer as a sub-model.
     _args.Write(ctx, _host);
     ctx.SaveModel(_scorer, "scorer");
 }
Esempio n. 13
0
        /// <summary>
        /// Saves the result of calling <c>Read</c> on a <c>MultiFileSource</c> constructed with null,
        /// under the "Loader" entry.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public void Save(ModelSaveContext ctx)
        {
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // The object that gets persisted is the loader produced by Read, not this instance.
            ctx.SaveModel(Read(new MultiFileSource(null)), "Loader");
        }
Esempio n. 14
0
 /// <summary>
 /// Saves the version info, the arguments, and the trend predictor (under the "trend" entry).
 /// The trend predictor is checked for null before the context is.
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 public override void Save(ModelSaveContext ctx)
 {
     // The trend predictor must exist before anything is written.
     Host.CheckValue(_trend, "No trend predictor was ever trained. The model cannot be saved.");

     Host.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();
     ctx.SetVersionInfo(GetVersionInfo());

     _args.Write(ctx, Host);
     ctx.SaveModel(_trend, "trend");
 }
        /// <summary>
        /// Saves the base state, the number of classes, and the sub-predictors: for each class index i,
        /// one predictor named with <c>SubPredictorFmt</c> and one per j &lt; i named with <c>SubPredictorFmt2</c>.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of classes
            ctx.Writer.Write(_numClasses);

            // Save other streams.
            for (int cls = 0; cls < _numClasses; cls++)
            {
                // Offset into _predictors for this class, as computed by GetIndex(cls, 0).
                int baseIndex = GetIndex(cls, 0);

                ctx.SaveModel(_predictors[baseIndex + cls], string.Format(SubPredictorFmt, cls));

                int other = 0;
                while (other < cls)
                {
                    ctx.SaveModel(_predictors[baseIndex + other], string.Format(SubPredictorFmt2, cls, other));
                    other++;
                }
            }
        }
        /// <summary>
        /// Saves the predictor as a sub-model; nothing is written to the main binary stream here.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public virtual void Save(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            // *** Binary format ***
            // <nothing>

            // The predictor lives in the directory named by ModelFileUtils.DirPredictor.
            ctx.SaveModel(Predictor, ModelFileUtils.DirPredictor);
        }
        /// <summary>
        /// Saves the version info and the calibrator, which is stored under the "Calibrator" entry.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: _calibrator

            ctx.SaveModel(_calibrator, @"Calibrator");
        }
Esempio n. 18
0
            /// <summary>
            /// Saves the combiner as a sub-model under the "Combiner" entry.
            /// </summary>
            /// <param name="ctx">The save context to write into.</param>
            protected override void SaveCore(ModelSaveContext ctx)
            {
                Host.AssertValue(ctx);

                // *** Binary format ***
                // <base>
                // The combiner
                // NOTE(review): no base call is made in this method; presumably the caller writes the
                // base part before invoking it -- confirm.

                ctx.SaveModel(Combiner, "Combiner");
            }
            /// <summary>
            /// Saves the class indices and values, two flags, the predictors, and the optional
            /// reclassification predictor, followed by a trailing marker byte.
            /// </summary>
            /// <param name="ctx">The save context to write into.</param>
            /// <param name="host">The host used for checks and for raising exceptions.</param>
            /// <param name="versionInfo">The version info recorded on the context.</param>
            public void SaveCore(ModelSaveContext ctx, IHost host, VersionInfo versionInfo)
            {
                // Both conditions are reported with the same "never trained" message.
                host.Check(Classes.Count > 0, "The model cannot be saved, it was never trained.");
                host.Check(Classes.Count == Classes.Length, "The model cannot be saved, it was never trained.");

                ctx.SetVersionInfo(versionInfo);
                ctx.Writer.WriteIntArray(Classes.Indices);

                // The class values are written with the array writer that matches the label type.
                if (LabelType == NumberDataViewType.Single)
                {
                    ctx.Writer.WriteSingleArray(Classes.Values as float[]);
                }
                else if (LabelType == NumberDataViewType.Byte)
                {
                    ctx.Writer.WriteByteArray(Classes.Values as byte[]);
                }
                else if (LabelType == NumberDataViewType.UInt16)
                {
                    // 16-bit values are widened and written as an unsigned int array.
                    uint[] widened = (Classes.Values as ushort[]).Select(value => (uint)value).ToArray();
                    ctx.Writer.WriteUIntArray(widened);
                }
                else if (LabelType == NumberDataViewType.UInt32)
                {
                    ctx.Writer.WriteUIntArray(Classes.Values as uint[]);
                }
                else
                {
                    throw host.Except("Unexpected type for LabelType.");
                }

                // Both flags are written as ints (1 or 0).
                ctx.Writer.Write(_singleColumn ? 1 : 0);
                ctx.Writer.Write(_labelKey ? 1 : 0);

                var models = Predictors;
                ctx.Writer.Write(models.Length);
                for (int index = 0; index < models.Length; index++)
                {
                    ctx.SaveModel(models[index], string.Format("M2B{0}", index));
                }

                // A one-byte presence flag precedes the optional reclassification predictor.
                if (_reclassificationPredictor == null)
                {
                    ctx.Writer.Write((byte)0);
                }
                else
                {
                    ctx.Writer.Write((byte)1);
                    ctx.SaveModel(_reclassificationPredictor, "Reclassification");
                }

                // Trailing marker byte.
                ctx.Writer.Write((byte)213);
            }
        /// <summary>
        /// Saves the version info and the mapper, which is stored under the "Mapper" entry.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public override void Save(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: _mapper

            ctx.SaveModel(_mapper, "Mapper");
        }
Esempio n. 21
0
 /// <summary>
 /// Saves the number of data transforms followed by each transform as a sub-model
 /// named "XF" plus its index.
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 public void Save(ModelSaveContext ctx)
 {
     _host.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();
     ctx.SetVersionInfo(GetVersionInfo());

     // The count comes first, then one sub-model per transform.
     ctx.Writer.Write(_dataTransforms.Length);

     int index = 0;
     while (index < _dataTransforms.Length)
     {
         ctx.SaveModel(_dataTransforms[index], string.Format("XF{0}", index));
         index++;
     }
 }
        /// <summary>
        /// Writes <paramref name="loader"/> into a new repository created over <paramref name="stream"/>
        /// and commits it.
        /// </summary>
        /// <param name="loader">The loader to save.</param>
        /// <param name="stream">The destination stream; it must be writable.</param>
        public static void SaveLoader(ILegacyDataLoader loader, Stream stream)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(stream, nameof(stream));
            Contracts.CheckParam(stream.CanWrite, nameof(stream), "Must be writable");

            using (var repository = RepositoryWriter.CreateNew(stream))
            {
                // The loader is stored in the directory named by ModelFileUtils.DirDataLoaderModel.
                ModelSaveContext.SaveModel(repository, loader, ModelFileUtils.DirDataLoaderModel);
                repository.Commit();
            }
        }
 /// <summary>
 /// Saves this object into a new repository created over <paramref name="outputStream"/>,
 /// under the name <c>TransformerChain.LoaderSignature</c>, and commits the repository.
 /// </summary>
 /// <param name="env">The environment used to start the "Saving pipeline" channel.</param>
 /// <param name="outputStream">The stream the repository is written to.</param>
 public void SaveTo(IHostEnvironment env, Stream outputStream)
 {
     // The channel is created first and disposed last; the repository lives inside its scope.
     using (var channel = env.Start("Saving pipeline"))
     using (var repository = RepositoryWriter.CreateNew(outputStream, channel))
     {
         channel.Trace("Saving transformer chain");
         ModelSaveContext.SaveModel(repository, this, TransformerChain.LoaderSignature);
         repository.Commit();
     }
 }
Esempio n. 24
0
        /// <summary>
        /// Saves the size of a single-precision float, the validation dataset proportion, and the
        /// meta predictor (under the "MetaPredictor" entry).
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        protected virtual void SaveCore(ModelSaveContext ctx)
        {
            Host.Assert(Meta != null);

            // *** Binary format ***
            // int: sizeof(Single)
            // Float: _validationDatasetProportion

            ctx.Writer.Write(sizeof(float));
            ctx.Writer.Write(ValidationDatasetProportion);

            // The meta predictor goes into its own sub-model.
            ctx.SaveModel(Meta, "MetaPredictor");
        }
Esempio n. 25
0
        /// <summary>
        /// Saves the version info and the ensemble, which is stored under the "Ensemble" entry.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        public void Save(ModelSaveContext ctx)
        {
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: _ensemble

            _host.AssertValue(_ensemble);
            ctx.SaveModel(_ensemble, "Ensemble");
        }
Esempio n. 26
0
 /// <summary>
 /// Saves the arguments, then three presence bytes (predictor, calibrator, scorer), then each
 /// component that is present as its own sub-model.
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 public override void Save(ModelSaveContext ctx)
 {
     _host.CheckValue(ctx, nameof(ctx));
     ctx.CheckAtModel();
     ctx.SetVersionInfo(GetVersionInfo());
     _args.Write(ctx, _host);

     bool hasPredictor = _predictor != null;
     bool hasCalibrator = _cali != null;
     bool hasScorer = _scorer != null;

     // All three presence flags are written before any sub-model.
     ctx.Writer.Write(hasPredictor ? (byte)1 : (byte)0);
     ctx.Writer.Write(hasCalibrator ? (byte)1 : (byte)0);
     ctx.Writer.Write(hasScorer ? (byte)1 : (byte)0);

     if (hasPredictor)
     {
         ctx.SaveModel(_predictor, "predictor");
     }

     if (hasCalibrator)
     {
         ctx.SaveModel(_cali, "calibrator");
     }

     if (hasScorer)
     {
         ctx.SaveModel(_scorer, "scorer");
     }
 }
Esempio n. 27
0
        /// <summary>
        /// Writes <paramref name="loader"/> into a new repository created over a write stream opened
        /// from <paramref name="file"/>, and commits it.
        /// </summary>
        /// <param name="loader">The loader to save.</param>
        /// <param name="file">The destination file handle; it must be writable.</param>
        public static void SaveLoader(IDataLoader loader, IFileHandle file)
        {
            Contracts.CheckValue(loader, nameof(loader));
            Contracts.CheckValue(file, nameof(file));
            Contracts.CheckParam(file.CanWrite, nameof(file), "Must be writable");

            using (var output = file.CreateWriteStream())
            {
                using (var repository = RepositoryWriter.CreateNew(output))
                {
                    // The loader is stored in the directory named by ModelFileUtils.DirDataLoaderModel.
                    ModelSaveContext.SaveModel(repository, loader, ModelFileUtils.DirDataLoaderModel);
                    repository.Commit();
                }
            }
        }
Esempio n. 28
0
 /// <summary>
 /// Writes the arguments and a marker byte, and, when <c>_args.serialize</c> is set, the predictor
 /// under the "predictor" entry.
 /// </summary>
 /// <param name="ctx">The save context to write into.</param>
 protected void Write(ModelSaveContext ctx)
 {
     _args.Write(ctx, _host);

     // Marker byte written right after the arguments.
     ctx.Writer.Write((byte)177);

     // Nothing more to do when the predictor is not meant to be serialized.
     if (!_args.serialize)
     {
         return;
     }

     if (_predictor == null)
     {
         throw _host.Except("_predictor cannot be null.");
     }

     ctx.SaveModel(_predictor, "predictor");
 }
Esempio n. 29
0
        /// <summary>
        /// Saves the base state, the model count, the weights, and then for each model its predictor
        /// (as a sub-model), selected-feature bits, and metrics. The combiner is saved last.
        /// </summary>
        /// <param name="ctx">The save context to write into.</param>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);

            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Single[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Single: metric value
            //     int: metric name (id of the metric name in the string table)

            ctx.Writer.Write(Models.Length);
            ctx.Writer.WriteSingleArray(Weights);

            // Save other streams.
            for (int modelIndex = 0; modelIndex < Models.Length; modelIndex++)
            {
                var current = Models[modelIndex];
                ctx.SaveModel(current.Predictor, string.Format(SubPredictorFmt, modelIndex));

                Host.AssertValueOrNull(current.SelectedFeatures);
                ctx.Writer.WriteBitArray(current.SelectedFeatures);

                // Metrics may be null; Utils.Size supplies the count that is written and iterated.
                Host.AssertValueOrNull(current.Metrics);
                int metricCount = Utils.Size(current.Metrics);
                ctx.Writer.Write(metricCount);

                for (int metricIndex = 0; metricIndex < metricCount; metricIndex++)
                {
                    var entry = current.Metrics[metricIndex];
                    // The metric value is stored in single precision.
                    ctx.Writer.Write((float)entry.Value);
                    ctx.SaveStringOrNull(entry.Key);
                }
            }

            ctx.SaveModel(Combiner, @"Combiner");
        }
Esempio n. 30
0
            /// <summary>
            /// Saves the inner bindable mapper as a sub-model, then invokes <c>SaveCore</c> through
            /// <c>Utils.MarshalActionInvoke</c> with the raw item type, then writes the metadata kind string.
            /// </summary>
            /// <param name="ctx">The save context to write into.</param>
            public void Save(ModelSaveContext ctx)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // *** Binary format ***
                // byte[]: A chunk of data saving both the type and value of the label names, as saved by the BinarySaver.
                // int: string id of the metadata kind

                // The wrapped mapper is stored in the directory named by _innerDir.
                ctx.SaveModel(_bindable, _innerDir);

                // NOTE(review): the <int> type argument looks like a placeholder that MarshalActionInvoke
                // replaces with _type.ItemType.RawType -- confirm against the helper's definition.
                Utils.MarshalActionInvoke(SaveCore<int>, _type.ItemType.RawType, ctx);

                ctx.SaveNonEmptyString(_metadataKind);
            }