示例#1
0
        /// <summary>
        /// Serializes a vector of feature names into <paramref name="ctx"/>.
        /// The total length is always written first. A dense vector is then marked with -1 and
        /// followed by one name per slot. A sparse vector writes its explicit-value count, its
        /// indices, and one name per explicit value.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        /// <param name="names">The names to save. Passed by reference (presumably to avoid copying the struct); not modified here.</param>
        public static void Save(ModelSaveContext ctx, ref VBuffer<DvText> names)
        {
            Contracts.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of features (size)
            // int: number of indices (-1 if dense)
            // int[]: indices (if sparse)
            // int[]: ids of names (matches either number of features or number of indices)

            var writer = ctx.Writer;
            writer.Write(names.Length);

            if (!names.IsDense)
            {
                // Sparse: all indices first, then the names in the same order.
                int explicitCount = names.Count;
                writer.Write(explicitCount);
                for (int slot = 0; slot < explicitCount; slot++)
                {
                    writer.Write(names.Indices[slot]);
                }
                for (int slot = 0; slot < explicitCount; slot++)
                {
                    ctx.SaveStringOrNull(names.Values[slot].ToString());
                }
                return;
            }

            // Dense: -1 sentinel instead of an index count, then one name per slot.
            writer.Write(-1);
            for (int slot = 0; slot < names.Length; slot++)
            {
                ctx.SaveStringOrNull(names.Values[slot].ToString());
            }
        }
示例#2
0
        /// <summary>
        /// Saves the prediction model, a schema-only snapshot of the training data, and the
        /// feature and output column names into <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // string: feature column name.
            // string: the name of the column where tree prediction values are stored.
            // string: the name of the column where trees' leaves are stored.
            // string: the name of the column where trees' paths are stored.

            ctx.SaveModel(Model, DirModel);

            ctx.SaveBinaryStream(DirTransSchema, writer =>
            {
                // The training schema is persisted by saving an EmptyDataView over TrainSchema.
                using (var channel = Host.Start("Saving train schema"))
                {
                    var saverArgs = new BinarySaver.Arguments { Silent = true };
                    var saver = new BinarySaver(Host, saverArgs);
                    var schemaOnlyView = new EmptyDataView(Host, TrainSchema);
                    DataSaverUtils.SaveDataView(channel, saver, schemaOnlyView, writer.BaseStream);
                }
            });

            // The feature column name goes through SaveString.
            // The three output column names go through SaveStringOrNull.
            ctx.SaveString(_featureDetachedColumn.Name);
            ctx.SaveStringOrNull(_treesColumnName);
            ctx.SaveStringOrNull(_leavesColumnName);
            ctx.SaveStringOrNull(_pathsColumnName);
        }
示例#3
0
        /// <summary>
        /// Writes this object's version info and the base class's state to <paramref name="ctx"/>,
        /// then appends the train label, score, and predicted label column names.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // <base info>
            // id of string: train label column (or null)
            // id of string: score column (or null)
            // id of string: predicted label column (or null)
            base.SaveCore(ctx);

            ctx.SaveStringOrNull(_trainLabelColumn);
            ctx.SaveStringOrNull(_scoreColumn);
            ctx.SaveStringOrNull(_predictedLabelColumn);
        }
        /// <summary>
        /// Writes the float size and transform count to <paramref name="ctx"/>. Then saves the
        /// loader into a "Loader" sub-directory, and finally saves each transform's model (into a
        /// directory named by <c>TransformDirTemplate</c>) followed by its tag and argument string.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        /// <param name="loaderSaveAction">Callback that saves the loader into its own sub-context.</param>
        /// <param name="transforms">The transforms to save; a null array is saved as zero transforms.</param>
        private static void SaveCore(ModelSaveContext ctx, Action<ModelSaveContext> loaderSaveAction, TransformEx[] transforms)
        {
            Contracts.AssertValue(ctx);
            Contracts.AssertValue(loaderSaveAction);
            Contracts.AssertValueOrNull(transforms);

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            // The assertion above permits a null array, so treat it as empty
            // instead of dereferencing it.
            int transformCount = transforms?.Length ?? 0;

            ctx.Writer.Write(sizeof(Float));
            ctx.Writer.Write(transformCount);

            using (var loaderCtx = new ModelSaveContext(ctx.Repository, Path.Combine(ctx.Directory ?? "", "Loader"), ModelLoadContext.ModelStreamName))
            {
                loaderSaveAction(loaderCtx);
                loaderCtx.Done();
            }

            for (int i = 0; i < transformCount; i++)
            {
                var transform = transforms[i];
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.SaveModel(transform.Transform, dirName);

                Contracts.AssertNonEmpty(transform.Tag);
                ctx.SaveNonEmptyString(transform.Tag);
                ctx.SaveStringOrNull(transform.ArgsString);
            }
        }
示例#5
0
        /// <summary>
        /// Writes the score column name (via SaveNonEmptyString) and then the label column name
        /// (via SaveStringOrNull) to <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        public virtual void Save(ModelSaveContext ctx)
        {
            // *** Binary format ***
            // int: Id of the score column name
            // int: Id of the label column name

            ctx.SaveNonEmptyString(ScoreCol);

            var labelColumn = LabelCol;
            ctx.SaveStringOrNull(labelColumn);
        }
示例#6
0
 /// <summary>
 /// Writes the stopwords language followed by the (possibly null) languages column name.
 /// </summary>
 /// <param name="ctx">The model save context to write into.</param>
 public void Save(ModelSaveContext ctx)
 {
     // *** Binary format ***
     // int: the stopwords list language
     // int: the id of languages column name

     ctx.Writer.Write((int)Lang);

     // Index and name must agree: either no column (index -1, null name)
     // or a non-negative index together with a non-null name.
     Contracts.Assert(
         (LangsColIndex == -1 && _langsColName == null) ||
         (LangsColIndex >= 0 && _langsColName != null));

     ctx.SaveStringOrNull(_langsColName);
 }
示例#7
0
            /// <summary>
            /// Writes every column info to <paramref name="ctx"/>. Each entry holds the column name,
            /// the raw item kind, optional key-type count, and the column's segments.
            /// </summary>
            /// <param name="ctx">The model save context to write into.</param>
            internal void Save(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: number of columns
                // foreach column:
                //   int: id of column name
                //   byte: DataKind
                //   byte: bool of whether this is a key type
                //   for a key type:
                //     ulong: count for key range
                //   int: number of segments
                //   foreach segment:
                //     string id: name
                //     int: min
                //     int: lim
                //     byte: force vector (verWrittenCur: verIsVectorSupported)
                var writer = ctx.Writer;
                writer.Write(Infos.Length);

                foreach (var info in Infos)
                {
                    ctx.SaveNonEmptyString(info.Name);

                    var itemType = info.ColType.GetItemType();
                    InternalDataKind rawKind = itemType.GetRawKind();
                    // The kind is persisted as a single byte, so it must round-trip through byte.
                    Contracts.Assert((InternalDataKind)(byte)rawKind == rawKind);
                    writer.Write((byte)rawKind);

                    // Key flag, then the key count only when the item type is a key type.
                    var keyType = itemType as KeyDataViewType;
                    writer.WriteBoolByte(keyType != null);
                    if (keyType != null)
                    {
                        writer.Write(keyType.Count);
                    }

                    // A missing segment array is written as zero segments.
                    var segments = info.Segments;
                    writer.Write(segments?.Length ?? 0);
                    if (segments != null)
                    {
                        foreach (var segment in segments)
                        {
                            ctx.SaveStringOrNull(segment.Name);
                            writer.Write(segment.Min);
                            writer.Write(segment.Lim);
                            writer.WriteBoolByte(segment.ForceVector);
                        }
                    }
                }
            }
        /// <summary>
        /// Writes the tree count, each tree (via its own Save), the bias, and the first
        /// input-initialization content string (possibly null) to <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        public void Save(ModelSaveContext ctx)
        {
            // *** Binary format ***
            // int: Number of trees
            // Regression trees (num trees of these)
            // double: Bias
            // int: Id to InputInitializationContent string

            ctx.Writer.Write(NumTrees);

            // Each tree serializes itself into the same context.
            foreach (RegressionTree tree in Trees)
            {
                tree.Save(ctx);
            }

            ctx.Writer.Write(Bias);
            ctx.SaveStringOrNull(_firstInputInitializationContent);
        }
示例#9
0
        /// <summary>
        /// Saves the prediction model, a schema-only snapshot of the training data, and the
        /// (possibly null) feature column name into <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        protected virtual void SaveCore(ModelSaveContext ctx)
        {
            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // id of string: feature column.

            ctx.SaveModel(Model, DirModel);

            ctx.SaveBinaryStream(DirTransSchema, writer =>
            {
                // The training schema is persisted by saving an EmptyDataView over TrainSchema.
                using (var channel = Host.Start("Saving train schema"))
                {
                    var saverArgs = new BinarySaver.Arguments { Silent = true };
                    var saver = new BinarySaver(Host, saverArgs);
                    var schemaOnlyView = new EmptyDataView(Host, TrainSchema);
                    DataSaverUtils.SaveDataView(channel, saver, schemaOnlyView, writer.BaseStream);
                }
            });

            ctx.SaveStringOrNull(FeatureColumn);
        }
            /// <summary>
            /// Writes each tokenizer group (language, source column, languages column, types flag)
            /// followed by the names of the column infos belonging to that group.
            /// </summary>
            /// <param name="ctx">The model save context to write into.</param>
            public void Save(ModelSaveContext ctx)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: count of group column infos (ie, count of source columns)
                // For each group column info
                //     int: the tokenizer language
                //     int: the id of source column name
                //     int: the id of languages column name
                //     bool: whether the types output is required
                //     For each column info that belongs to this group column info
                //     (either one column info for tokens or two for tokens and types)
                //          int: the id of the column name

                Contracts.Assert(Utils.Size(GroupInfos) > 0);
                ctx.Writer.Write(GroupInfos.Length);

                // Column infos are consumed in order across groups, so one running
                // index (infoIndex) advances through Infos as the groups are written.
                int infoIndex = 0;
                for (int groupId = 0; groupId < GroupInfos.Length; groupId++)
                {
                    var group = GroupInfos[groupId];

                    Contracts.Assert(Enum.IsDefined(typeof(Language), group.Lang));
                    ctx.Writer.Write((int)group.Lang);
                    ctx.SaveNonEmptyString(group.SrcColName);
                    ctx.SaveStringOrNull(group.LangsColName);
                    ctx.Writer.WriteBoolByte(group.RequireTypes);

                    // Tokens only -> 1 column info; tokens plus types -> 2.
                    int columnsInGroup = group.RequireTypes ? 2 : 1;
                    for (int k = 0; k < columnsInGroup; k++, infoIndex++)
                    {
                        Contracts.Assert(Infos[infoIndex].GroupInfoId == groupId);
                        ctx.SaveNonEmptyString(GetColumnNameCore(infoIndex));
                    }
                }

                Contracts.Assert(infoIndex == Infos.Length);
            }
示例#11
0
        /// <summary>
        /// Saves the base state, the model count and weights, then each sub-model's predictor,
        /// selected-feature bits, and metrics, and finally the combiner.
        /// </summary>
        /// <param name="ctx">The model save context to write into.</param>
        private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);

            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Single[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Single: metric value
            //     int: metric name (id of the metric name in the string table)
            // Each sub-predictor is saved as a separate model (SubPredictorFmt); the combiner is saved last as "Combiner".

            ctx.Writer.Write(Models.Length);
            ctx.Writer.WriteSingleArray(Weights);

            for (int modelIndex = 0; modelIndex < Models.Length; modelIndex++)
            {
                var subModel = Models[modelIndex];
                ctx.SaveModel(subModel.Predictor, string.Format(SubPredictorFmt, modelIndex));

                Host.AssertValueOrNull(subModel.SelectedFeatures);
                ctx.Writer.WriteBitArray(subModel.SelectedFeatures);

                // Utils.Size is used for the count, so a null metrics collection is written as zero metrics.
                Host.AssertValueOrNull(subModel.Metrics);
                int metricCount = Utils.Size(subModel.Metrics);
                ctx.Writer.Write(metricCount);
                for (int metricIndex = 0; metricIndex < metricCount; metricIndex++)
                {
                    var entry = subModel.Metrics[metricIndex];
                    ctx.Writer.Write((float)entry.Value);
                    ctx.SaveStringOrNull(entry.Key);
                }
            }

            ctx.SaveModel(Combiner, @"Combiner");
        }
示例#12
0
 /// <summary>
 /// Saves the model payload via SaveModelCore, then appends the (possibly null) feature column name.
 /// </summary>
 /// <param name="ctx">The model save context to write into.</param>
 private protected virtual void SaveCore(ModelSaveContext ctx)
 {
     SaveModelCore(ctx);

     var featureColumnName = FeatureColumnName;
     ctx.SaveStringOrNull(featureColumnName);
 }