/// <summary>
/// Writes the label histogram, the per-label feature histograms and the absent-feature
/// log probabilities after the base class payload.
/// </summary>
private protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: _labelCount                    (length prefix emitted by WriteIntArray)
            // int[_labelCount]: _labelHistogram
            // int: _featureCount
            // for each label i in [0, _labelCount) with _labelHistogram[i] > 0:
            //     int[_featureCount]: _featureHistogram[i]
            //     (rows for labels whose histogram count is 0 are NOT written)
            // double[_labelCount]: _absentFeaturesLogProb
            //
            // NOTE(review): a reader must apply the same _labelHistogram[i] > 0 rule to know
            // which feature-histogram rows are present — confirm against the load path.
            ctx.Writer.WriteIntArray(_labelHistogram.AsSpan(0, _labelCount));
            ctx.Writer.Write(_featureCount);
            for (int i = 0; i < _labelCount; i++)
            {
                // Only labels with a non-zero histogram count get a stored feature row.
                if (_labelHistogram[i] > 0)
                {
                    ctx.Writer.WriteIntsNoCount(_featureHistogram[i].AsSpan(0, _featureCount));
                }
            }

            ctx.Writer.WriteDoublesNoCount(_absentFeaturesLogProb.AsSpan(0, _labelCount));
        }
        /// <summary>
        /// Saves the serialized ONNX model followed by the input and output column names.
        /// </summary>
        public override void Save(ModelSaveContext ctx)
        {
            // Public entry point: use Check* so a null context fails in release builds too.
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            // Validate the column lists before anything is written, so an invalid
            // object never leaves a partially written model behind.
            Host.CheckNonEmpty(Inputs, nameof(Inputs));
            Host.CheckNonEmpty(Outputs, nameof(Outputs));

            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // stream "OnnxModel": bytes of Model.ToByteArray()
            // int: number of input columns
            // for each input column
            //   int: id of input column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name
            ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); });

            ctx.Writer.Write(Inputs.Length);
            foreach (var colName in Inputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            ctx.Writer.Write(Outputs.Length);
            foreach (var colName in Outputs)
            {
                ctx.SaveNonEmptyString(colName);
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Persists the range-filter settings: the filtered column name, the bounds and the three flags.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // Serialized layout, in write order:
            //   int    : sizeof(float) (presumably verified by the loader — confirm)
            //   int    : string id of the filtered column's name
            //   double : lower bound (_min)
            //   double : upper bound (_max)
            //   byte   : complement flag
            //   byte   : includeMin flag
            //   byte   : includeMax flag
            var writer = ctx.Writer;
            writer.Write(sizeof(float));

            string columnName = Source.Schema[_index].Name;
            ctx.SaveNonEmptyString(columnName);

            // The range must be non-degenerate before its bounds are persisted.
            Host.Assert(_min < _max);
            writer.Write(_min);
            writer.Write(_max);

            writer.WriteBoolByte(_complement);
            writer.WriteBoolByte(_includeMin);
            writer.WriteBoolByte(_includeMax);
        }
Exemplo n.º 4
0
            /// <summary>
            /// Saves the TensorFlow graph definition, the input column names and the output column name.
            /// </summary>
            public void Save(ModelSaveContext ctx)
            {
                // Public entry point: validate in release builds as well.
                _host.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel();
                ctx.SetVersionInfo(GetVersionInfo());

                // *** Binary format ***
                // stream "TFModel": serialized graph definition
                // int: number of input columns
                // for each input column
                //   int: id of input column name
                // int: id of output column name

                // Check the invariant before writing, using the host like the rest of this class.
                _host.AssertNonEmpty(_inputColNames);

                // TFBuffer wraps native memory; dispose it once its bytes have been copied out.
                // SaveBinaryStream runs the callback synchronously, so the buffer is still alive there.
                using (var buffer = new TFBuffer())
                {
                    _session.Graph.ToGraphDef(buffer);

                    ctx.SaveBinaryStream("TFModel", w =>
                    {
                        w.WriteByteArray(buffer.ToArray());
                    });
                }

                ctx.Writer.Write(_inputColNames.Length);
                foreach (var colName in _inputColNames)
                {
                    ctx.SaveNonEmptyString(colName);
                }

                ctx.SaveNonEmptyString(_outputColName);
            }
        /// <summary>
        /// Save model to the given context: the matrix dimensions, the rank, and both factor matrices.
        /// </summary>
        void ICanSaveModel.Save(ModelSaveContext ctx)
        {
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of rows (m), the limit on row
            // int: number of columns (n), the limit on column
            // int: rank of factor matrices (k)
            // float[m * k]: the left factor matrix
            // float[k * n]: the right factor matrix

            // Validate everything before writing, so an inconsistent model writes nothing.
            _host.Check(NumberOfRows > 0, "Number of rows must be positive");
            _host.Check(NumberOfColumns > 0, "Number of columns must be positive");
            _host.Check(ApproximationRank > 0, "Number of latent factors must be positive");
            // Multiply in 64-bit: an int product could overflow and wrap around to a value
            // that accidentally equals the stored array size.
            _host.Check(Utils.Size(_leftFactorMatrix) == (long)NumberOfRows * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix P in LIBMF paper)");
            _host.Check(Utils.Size(_rightFactorMatrix) == (long)NumberOfColumns * ApproximationRank, "Unexpected matrix size of a factor matrix (matrix Q in LIBMF paper)");

            ctx.Writer.Write(NumberOfRows);
            ctx.Writer.Write(NumberOfColumns);
            ctx.Writer.Write(ApproximationRank);
            Utils.WriteSinglesNoCount(ctx.Writer, _leftFactorMatrix);
            Utils.WriteSinglesNoCount(ctx.Writer, _rightFactorMatrix);
        }
        /// <summary>
        /// Saves the TensorFlow graph definition followed by the input and output column names.
        /// </summary>
        public void Save(ModelSaveContext ctx)
        {
            // Public entry point: validate in release builds as well.
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
            // *** Binary format ***
            // stream: tensorFlow model.
            // int: number of input columns
            // for each input column
            //   int: id of int column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name

            // Check invariants before any bytes are written.
            _host.AssertNonEmpty(Inputs);
            _host.AssertNonEmpty(Outputs);

            // TFBuffer wraps native memory; dispose it after its contents have been copied.
            // SaveBinaryStream invokes the callback synchronously, so the buffer is alive there.
            using (var buffer = new TFBuffer())
            {
                Session.Graph.ToGraphDef(buffer);

                ctx.SaveBinaryStream("TFModel", w =>
                {
                    w.WriteByteArray(buffer.ToArray());
                });
            }

            ctx.Writer.Write(Inputs.Length);
            foreach (var colName in Inputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            ctx.Writer.Write(Outputs.Length);
            foreach (var colName in Outputs)
            {
                ctx.SaveNonEmptyString(colName);
            }
        }
Exemplo n.º 7
0
        /// <summary>
        /// Saves the ONNX model file contents, the input/output column names, any custom
        /// tensor shapes, and the recursion limit.
        /// </summary>
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Host.AssertValue(ctx);

            ctx.CheckAtModel();

            // Validate the column lists before anything is written, so a bad object
            // never leaves a partially written model behind.
            Host.CheckNonEmpty(Inputs, nameof(Inputs));
            Host.CheckNonEmpty(Outputs, nameof(Outputs));

            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // stream "OnnxModel": raw bytes of the file behind Model.ModelStream
            // int: number of input columns, then an id of each input column name
            // int: number of output columns, then an id of each output column name
            // int: number of custom shape infos
            // for each custom shape info
            //   int: id of the tensor name
            //   int[]: shape (with length prefix)
            // (type of RecursionLimit as declared on the options): recursion limit
            ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelStream.Name)); });

            ctx.Writer.Write(Inputs.Length);
            foreach (var colName in Inputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            ctx.Writer.Write(Outputs.Length);
            foreach (var colName in Outputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            // Save custom-provided shapes. Those shapes overwrite shapes loaded from the ONNX model file.
            var customShapeInfos = _options.CustomShapeInfos;
            int customShapeInfosLength = customShapeInfos?.Length ?? 0;

            ctx.Writer.Write(customShapeInfosLength);
            for (int i = 0; i < customShapeInfosLength; ++i)
            {
                var info = customShapeInfos[i];
                ctx.SaveNonEmptyString(info.Name);
                ctx.Writer.WriteIntArray(info.Shape);
            }

            ctx.Writer.Write(_options.RecursionLimit);
        }
Exemplo n.º 8
0
        /// <summary>
        /// Stamps version info, lets the base class write its payload, then appends every
        /// dictionary in <c>Tables</c> as a count followed by its key/value pairs.
        /// </summary>
        public override void Save(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.SetVersionInfo(GetVersionInfo());
            base.Save(ctx);

            // Written after base.Save(ctx), one section per dictionary in Tables, in enumeration order:
            //     int: number N of entries in the dictionary
            //     N times:
            //         long: key
            //         Single: value (asserted non-negative)
            var writer = ctx.Writer;
            foreach (var dictionary in Tables)
            {
                writer.Write(dictionary.Count);
                foreach (var entry in dictionary)
                {
                    var key = entry.Key;
                    var value = entry.Value;
                    writer.Write(key);
                    Contracts.Assert(value >= 0);
                    writer.Write(value);
                }
            }
        }
Exemplo n.º 9
0
 /// <summary>
 /// Records this object's version info, then hands the payload off to the wrapped transform.
 /// </summary>
 public void Save(ModelSaveContext ctx)
 {
     ctx.CheckAtModel();
     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);

     // The wrapped transform writes the actual model payload.
     var inner = _transform;
     inner.Save(ctx);
 }
 /// <summary>
 /// Stamps this object's version info and writes its payload via <c>SaveCore</c>.
 /// </summary>
 /// <param name="ctx">The save context; must not be null.</param>
 public void Save(ModelSaveContext ctx)
 {
     // Public entry point: Check (not Assert) so a null context is rejected in release builds too.
     Contracts.CheckValue(ctx, nameof(ctx));
     ctx.SetVersionInfo(GetVersionInfo());
     SaveCore(ctx);
 }
Exemplo n.º 11
0
 /// <summary>
 /// Writes the base payload and then stamps this type's version info.
 /// Only the configuration with <c>Normalize</c> disabled is expected here.
 /// </summary>
 protected override void SaveCore(ModelSaveContext ctx)
 {
     // Saving is only expected for the non-normalizing configuration.
     Contracts.Assert(!Normalize);

     base.SaveCore(ctx);
     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);
 }
Exemplo n.º 12
0
 /// <summary>
 /// Records version info, then delegates serialization to the wrapped transform through its
 /// explicit <c>ICanSaveModel</c> implementation.
 /// </summary>
 void ICanSaveModel.Save(ModelSaveContext ctx)
 {
     ctx.CheckAtModel();
     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);

     var inner = (ICanSaveModel)_transform;
     inner.Save(ctx);
 }
Exemplo n.º 13
0
 /// <summary>
 /// Lets the base class write its payload, then stamps this type's version info.
 /// </summary>
 private protected override void SaveCore(ModelSaveContext ctx)
 {
     base.SaveCore(ctx);

     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);
 }
        /// <summary>
        /// Saves the TensorFlow model, either as a frozen graph definition or as the full
        /// SavedModel directory tree, followed by the input and output column names.
        /// </summary>
        public void Save(ModelSaveContext ctx)
        {
            // Public entry point: validate in release builds as well.
            _host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // byte: indicator for frozen models
            // stream: tensorFlow model.
            //   frozen:     stream "TFModel" holding the serialized graph definition
            //   not frozen: stream "TFSavedModel" holding
            //     int: number of files under _savedModelPath (recursive)
            //     for each file
            //       string: path relative to _savedModelPath
            //       long: file length in bytes
            //       byte[length]: file contents
            // int: number of input columns
            // for each input column
            //   int: id of int column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name

            // Check invariants before any bytes are written.
            _host.AssertNonEmpty(Inputs);
            _host.AssertNonEmpty(Outputs);

            var isFrozen = string.IsNullOrEmpty(_savedModelPath);

            ctx.Writer.WriteBoolByte(isFrozen);
            if (isFrozen)
            {
                // TFBuffer wraps native memory; dispose it once its bytes are copied out.
                // SaveBinaryStream runs the callback synchronously.
                using (var buffer = new TFBuffer())
                {
                    Session.Graph.ToGraphDef(buffer);
                    ctx.SaveBinaryStream("TFModel", w =>
                    {
                        w.WriteByteArray(buffer.ToArray());
                    });
                }
            }
            else
            {
                // Directory.GetFiles returns paths prefixed with the directory exactly as given.
                // Strip trailing separators so Substring(root.Length + 1) removes exactly
                // "root" + separator, and never the first character of the relative path
                // when _savedModelPath ends with a separator.
                string root = _savedModelPath.TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar);

                ctx.SaveBinaryStream("TFSavedModel", w =>
                {
                    string[] modelFilePaths = Directory.GetFiles(_savedModelPath, "*", SearchOption.AllDirectories);
                    w.Write(modelFilePaths.Length);

                    foreach (var fullPath in modelFilePaths)
                    {
                        var relativePath = fullPath.Substring(root.Length + 1);
                        w.Write(relativePath);

                        // Open read-only. FileMode.Open alone requests read/write access, which
                        // fails for read-only model files and is unnecessary for copying.
                        using (var fs = new FileStream(fullPath, FileMode.Open, FileAccess.Read, FileShare.Read))
                        {
                            long fileLength = fs.Length;
                            w.Write(fileLength);
                            long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
                            _host.Assert(actualWritten == fileLength);
                        }
                    }
                });
            }

            ctx.Writer.Write(Inputs.Length);
            foreach (var colName in Inputs)
            {
                ctx.SaveNonEmptyString(colName);
            }

            ctx.Writer.Write(Outputs.Length);
            foreach (var colName in Outputs)
            {
                ctx.SaveNonEmptyString(colName);
            }
        }
Exemplo n.º 15
0
 /// <summary>
 /// Stamps this type's version info; the payload itself is written by the column bindings.
 /// </summary>
 private protected override void SaveCore(ModelSaveContext ctx)
 {
     Contracts.AssertValue(ctx);

     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);

     // This override writes nothing itself beyond the version info.
     _bindings.SaveModel(ctx);
 }
Exemplo n.º 16
0
        /// <summary>
        /// Writes the multi-class linear model after the base class payload: dimensions, biases,
        /// the weight matrix (dense, or CSR-style when any row is written sparsely), and, when
        /// present, the label names and model statistics as separate sub-streams.
        /// </summary>
        protected override void SaveCore(ModelSaveContext ctx)
        {
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

            // Internal consistency: one bias and one weight vector per class.
            Host.Assert(_biases.Length == _numClasses);
            Host.Assert(_biases.Length == _weights.Length);
#if DEBUG
            foreach (var fw in _weights)
            {
                Host.Assert(fw.Length == _numFeatures);
            }
#endif
            // *** Binary format ***
            // int: number of features
            // int: number of classes = number of biases
            // float[]: biases
            // (weight matrix, in CSR if sparse)
            // (see https://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000)
            // int: number of row start indices (_numClasses + 1 if sparse, 0 if dense)
            // int[]: row start indices
            // int: total number of column indices (0 if dense)
            // int[]: column index of each non-zero weight
            // int: total number of non-zero weights  (same as number of column indices if sparse, num of classes * num of features if dense)
            // float[]: non-zero weights
            // bool: whether label names are present
            // int[]: Id of label names (optional, in a separate stream)
            // LinearModelStatistics: model statistics (optional, in a separate stream)

            ctx.Writer.Write(_numFeatures);
            ctx.Writer.Write(_numClasses);
            ctx.Writer.WriteFloatsNoCount(_biases, _numClasses);
            // _weights == _weightsDense means we checked that all vectors in _weights
            // are actually dense, and so we assigned the same object, or it came dense
            // from deserialization.
            if (_weights == _weightsDense)
            {
                // Dense layout: no row starts, no column indices, then every weight of
                // every class row by row (_numFeatures values per class).
                ctx.Writer.Write(0); // Number of starts.
                ctx.Writer.Write(0); // Number of indices.
                ctx.Writer.Write(_numFeatures * _weights.Length);
                foreach (var fv in _weights)
                {
                    Host.Assert(fv.Length == _numFeatures);
                    ctx.Writer.WriteFloatsNoCount(fv.Values, _numFeatures);
                }
            }
            else
            {
                // Number of starts.
                ctx.Writer.Write(_numClasses + 1);

                // Starts always starts with 0.
                // Each following start is the running total of entries up to and including that class row.
                int numIndices = 0;
                ctx.Writer.Write(numIndices);
                for (int i = 0; i < _weights.Length; i++)
                {
                    // REVIEW: Assuming the presence of *any* zero justifies
                    // writing in sparse format seems stupid, but might be difficult
                    // to change without changing the format since the presence of
                    // any sparse vector means we're writing indices anyway. Revisit.
                    // This is actually a bug waiting to happen: sparse/dense vectors
                    // can have different dot products even if they are logically the
                    // same vector.
                    // NOTE(review): for sparse rows the index/value loops below write fw.Count
                    // entries, while this uses NonZeroCount. If NonZeroCount excludes explicitly
                    // stored zeros, the row starts would disagree with the written entries
                    // (caught only by the debug-only count asserts below) — confirm.
                    numIndices += NonZeroCount(ref _weights[i]);
                    ctx.Writer.Write(numIndices);
                }

                ctx.Writer.Write(numIndices);
                {
                    // just scoping the count so we can use another further down
                    // Column indices: positions of non-zero values for dense rows,
                    // every stored index for sparse rows.
                    int count = 0;
                    foreach (var fw in _weights)
                    {
                        if (fw.IsDense)
                        {
                            for (int i = 0; i < fw.Length; i++)
                            {
                                if (fw.Values[i] != 0)
                                {
                                    ctx.Writer.Write(i);
                                    count++;
                                }
                            }
                        }
                        else
                        {
                            ctx.Writer.WriteIntsNoCount(fw.Indices, fw.Count);
                            count += fw.Count;
                        }
                    }
                    Host.Assert(count == numIndices);
                }

                ctx.Writer.Write(numIndices);

                {
                    // Weight values, in exactly the same order as the column indices above.
                    int count = 0;
                    foreach (var fw in _weights)
                    {
                        if (fw.IsDense)
                        {
                            for (int i = 0; i < fw.Length; i++)
                            {
                                if (fw.Values[i] != 0)
                                {
                                    ctx.Writer.Write(fw.Values[i]);
                                    count++;
                                }
                            }
                        }
                        else
                        {
                            ctx.Writer.WriteFloatsNoCount(fw.Values, fw.Count);
                            count += fw.Count;
                        }
                    }
                    Host.Assert(count == numIndices);
                }
            }

            // Label names, if any, go to their own binary stream written by SaveLabelNames.
            Contracts.AssertValueOrNull(_labelNames);
            if (_labelNames != null)
            {
                ctx.SaveBinaryStream(LabelNamesSubModelFilename, w => SaveLabelNames(ctx, w));
            }

            // Statistics, if any, are saved through a nested context rooted at
            // <ctx.Directory>/<ModelStatsSubModelFilename>.
            Contracts.AssertValueOrNull(_stats);
            if (_stats != null)
            {
                using (var statsCtx = new ModelSaveContext(ctx.Repository,
                                                           Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename), ModelLoadContext.ModelStreamName))
                {
                    _stats.Save(statsCtx);
                    statsCtx.Done();
                }
            }
        }
Exemplo n.º 17
0
 /// <summary>
 /// Validates the context, stamps this type's version info, then defers to the base implementation.
 /// </summary>
 private protected override void SaveModel(ModelSaveContext ctx)
 {
     Contracts.CheckValue(ctx, nameof(ctx));

     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);

     base.SaveModel(ctx);
 }
 /// <summary>
 /// Validates the context, stamps this type's version info, then defers to the base implementation.
 /// </summary>
 public override void Save(ModelSaveContext ctx)
 {
     Contracts.CheckValue(ctx, nameof(ctx));

     var version = GetVersionInfo();
     ctx.SetVersionInfo(version);

     base.Save(ctx);
 }
Exemplo n.º 19
0
        /// <summary>
        /// Serializes a vector of text values with a codec obtained from <paramref name="factory"/>,
        /// and also writes a human-readable "Terms.txt" listing of the non-empty entries.
        /// </summary>
        private static void Save(IChannel ch, ModelSaveContext ctx, CodecFactory factory, ref VBuffer<ReadOnlyMemory<char>> values)
        {
            Contracts.AssertValue(ch);
            ch.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // Codec parameterization: A codec parameterization that should be a ReadOnlyMemory codec
            // int: n, the number of bytes used to write the values
            // byte[n]: As encoded using the codec

            // Get the codec from the factory
            IValueCodec codec;
            var result = factory.TryGetCodec(new VectorType(TextType.Instance), out codec);

            ch.Assert(result);
            ch.Assert(codec.Type.IsVector);
            ch.Assert(codec.Type.VectorSize == 0);
            ch.Assert(codec.Type.ItemType.RawType == typeof(ReadOnlyMemory<char>));
            var textCodec = (IValueCodec<VBuffer<ReadOnlyMemory<char>>>)codec;

            factory.WriteCodec(ctx.Writer.BaseStream, codec);
            using (var mem = new MemoryStream())
            {
                using (var writer = textCodec.OpenWriter(mem))
                {
                    writer.Write(ref values);
                    writer.Commit();
                }
                ctx.Writer.WriteByteArray(mem.ToArray());
            }

            // Make this resemble, more or less, the auxiliary output from the TermTransform.
            // It will differ somewhat due to the vector being possibly sparse. Entries whose
            // text is empty are skipped entirely; every other entry is written as
            // "<index>\t<text>" on its own line.
            // A ref parameter cannot be captured by a lambda, so take a copy for the callback.
            var v = values;

            char[] buffer = null;
            ctx.SaveTextStream("Terms.txt", writer =>
            {
                writer.WriteLine("# Number of terms = {0} of length {1}", v.Count, v.Length);
                foreach (var pair in v.Items())
                {
                    var text = pair.Value;
                    if (text.IsEmpty)
                    {
                        continue;
                    }

                    writer.Write("{0}\t", pair.Key);
                    // REVIEW: What about escaping this, *especially* for linebreaks?
                    // Do C# and .NET really have no equivalent to Python's "repr"? :(
                    // Copy into a reusable char[] so the TextWriter(char[], int, int) overload can be used.
                    Utils.EnsureSize(ref buffer, text.Length);
                    text.Span.CopyTo(buffer);
                    writer.WriteLine(buffer, 0, text.Length);
                }
            });
        }