/// <summary>
/// Deserialization constructor. The base class reads the linear model itself; this
/// constructor additionally attaches the optional <c>LinearModelStatistics</c> side stream
/// when the model was written by a version new enough to contain it.
/// </summary>
private LinearBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, RegistrationName, ctx)
{
    // Models written at version 0x00020001 or earlier carry no model statistics,
    // so there is nothing further to read once the base class has loaded.
    bool mayHaveStatistics = ctx.Header.ModelVerWritten > 0x00020001;
    if (!mayHaveStatistics)
    {
        return;
    }

    // *** Binary format ***
    // (Base class)
    // LinearModelStatistics: model statistics (optional, in a separate stream)
    string statsPath = Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename);
    using (var entry = ctx.Repository.OpenEntryOrNull(statsPath, ModelLoadContext.ModelStreamName))
    {
        if (entry != null)
        {
            // The statistics live in their own sub-stream, so read them through a nested load context.
            using (var subCtx = new ModelLoadContext(ctx.Repository, entry, statsPath))
            {
                _stats = LinearModelStatistics.Create(Host, subCtx);
            }
        }
        else
        {
            // The statistics stream is optional; its absence simply means no statistics.
            _stats = null;
        }
    }
}
/// <summary>
/// Deserialization constructor. Reads the feature and class counts, the per-class biases,
/// and the weight matrix (stored either densely or in CSR form) from <paramref name="ctx"/>.
/// It then loads the optional label-name and model-statistics side streams.
/// </summary>
private MulticlassLogisticRegressionPredictor(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, RegistrationName, ctx)
{
    // *** Binary format ***
    // int: number of features
    // int: number of classes = number of biases
    // float[]: biases
    // (weight matrix, in CSR if sparse)
    // (see https://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000)
    // int: number of row start indices (_numClasses + 1 if sparse, 0 if dense)
    // int[]: row start indices
    // int: total number of column indices (0 if dense)
    // int[]: column index of each non-zero weight
    // int: total number of non-zero weights (same as number of column indices if sparse, num of classes * num of features if dense)
    // float[]: non-zero weights
    // int[]: Id of label names (optional, in a separate stream)
    // LinearModelStatistics: model statistics (optional, in a separate stream)

    _numFeatures = ctx.Reader.ReadInt32();
    Host.CheckDecode(_numFeatures >= 1);

    _numClasses = ctx.Reader.ReadInt32();
    Host.CheckDecode(_numClasses >= 1);

    _biases = ctx.Reader.ReadFloatArray(_numClasses);

    int numStarts = ctx.Reader.ReadInt32();
    if (numStarts == 0)
    {
        // The weights are entirely dense.
        int numIndices = ctx.Reader.ReadInt32();
        Host.CheckDecode(numIndices == 0);

        int numWeights = ctx.Reader.ReadInt32();
        // Compare in 64-bit arithmetic: the int product _numClasses * _numFeatures wraps silently
        // (C# is unchecked by default), which would let a corrupt header pass this sanity check.
        Host.CheckDecode(numWeights == (long)_numClasses * _numFeatures);

        _weights = new VBuffer<float>[_numClasses];
        for (int i = 0; i < _weights.Length; i++)
        {
            var w = ctx.Reader.ReadFloatArray(_numFeatures);
            _weights[i] = new VBuffer<float>(_numFeatures, w);
        }
        _weightsDense = _weights;
    }
    else
    {
        // Read weight matrix as CSR.
        // 64-bit comparison so that _numClasses == int.MaxValue cannot wrap to a negative start count.
        Host.CheckDecode(numStarts == (long)_numClasses + 1);
        int[] starts = ctx.Reader.ReadIntArray(numStarts);
        Host.CheckDecode(starts[0] == 0);
        Host.CheckDecode(Utils.IsSorted(starts));

        int numIndices = ctx.Reader.ReadInt32();
        Host.CheckDecode(numIndices == starts[starts.Length - 1]);

        // Row i owns the column indices in [starts[i], starts[i + 1]); each row's indices must be
        // strictly increasing and lie within [0, _numFeatures).
        var indices = new int[_numClasses][];
        for (int i = 0; i < indices.Length; i++)
        {
            indices[i] = ctx.Reader.ReadIntArray(starts[i + 1] - starts[i]);
            Host.CheckDecode(Utils.IsIncreasing(0, indices[i], _numFeatures));
        }

        int numValues = ctx.Reader.ReadInt32();
        Host.CheckDecode(numValues == numIndices);

        _weights = new VBuffer<float>[_numClasses];
        for (int i = 0; i < _weights.Length; i++)
        {
            float[] values = ctx.Reader.ReadFloatArray(starts[i + 1] - starts[i]);
            // Utils.Size is used instead of values.Length so that a possibly-null array
            // (e.g. for an empty row) yields a count of 0.
            _weights[i] = new VBuffer<float>(_numFeatures, Utils.Size(values), values, indices[i]);
        }
    }

    WarnOnOldNormalizer(ctx, GetType(), Host);
    InputType = new VectorType(NumberType.R4, _numFeatures);
    OutputType = new VectorType(NumberType.R4, _numClasses);

    // REVIEW: The label names should not be saved a second time alongside the predictor;
    // they should come from the label column's schema metadata instead.
    string[] labelNames = null;
    if (ctx.TryLoadBinaryStream(LabelNamesSubModelFilename, r => labelNames = LoadLabelNames(ctx, r)))
    {
        _labelNames = labelNames;
    }

    // Model statistics are optional and live in their own sub-stream.
    string statsDir = Path.Combine(ctx.Directory ?? "", ModelStatsSubModelFilename);
    using (var statsEntry = ctx.Repository.OpenEntryOrNull(statsDir, ModelLoadContext.ModelStreamName))
    {
        if (statsEntry == null)
        {
            _stats = null;
        }
        else
        {
            using (var statsCtx = new ModelLoadContext(ctx.Repository, statsEntry, statsDir))
            {
                _stats = LinearModelStatistics.Create(Host, statsCtx);
            }
        }
    }
}