Пример #1
0
        // Factory method for SignatureLoadModel.
        // Validates the environment and context, reads and validates the float-size field that
        // is present only for models written at version 0x00010001, and hands the remainder of
        // the deserialization to the loading constructor.
        private static PrincipalComponentAnalysisTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var registeredHost = env.Register(nameof(PrincipalComponentAnalysisTransformer));
            registeredHost.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            bool hasFloatSizeField = ctx.Header.ModelVerWritten == 0x00010001;
            if (hasFloatSizeField)
            {
                // The stored size must match sizeof(float) in this build.
                int floatByteCount = ctx.Reader.ReadInt32();
                env.CheckDecode(floatByteCount == sizeof(float));
            }

            return new PrincipalComponentAnalysisTransformer(registeredHost, ctx);
        }
Пример #2
0
            /// <summary>
            /// Factory for loading a <see cref="LabelNameBindableMapper"/>: validates the arguments,
            /// confirms the model signature, then constructs the mapper through <c>Apply</c>.
            /// </summary>
            public static ISchemaBindableMapper Create(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                var mapperHost = env.Register(LoaderSignature);
                mapperHost.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // byte[]: A chunk of data saving both the type and value of the label names, as saved by the BinarySaver.
                // int: string id of the metadata kind
                // (nothing is read here; ctx is passed straight to the constructor)
                return mapperHost.Apply("Loading Model", channel => new LabelNameBindableMapper(mapperHost, ctx));
            }
Пример #3
0
        /// <summary>
        /// Factory for SignatureLoadModel: rebuilds an <see cref="ExpressionTransformer"/> by reading
        /// each saved column's inputs, expression and output name, then re-parsing and re-binding
        /// the expression against the stored input types.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Model load context; must not be null.</param>
        /// <returns>The reconstructed transformer.</returns>
        private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: number of output columns
            // for each output column:
            //   int: number of inputs
            //   foreach input
            //     int: Id of the input column name
            //   int: Id of the expression
            //   int: Id of the output column name
            //   int: the index of the vector input (or -1)
            //   int[]: The data kinds of the input columns

            var columnCount = ctx.Reader.ReadInt32();
            env.CheckDecode(columnCount > 0);

            var columns = new ColumnInfo[columnCount];
            for (int i = 0; i < columnCount; i++)
            {
                var inputSize = ctx.Reader.ReadInt32();
                env.CheckDecode(inputSize >= 0);
                var inputColumnNames = new string[inputSize];
                for (int j = 0; j < inputSize; j++)
                {
                    inputColumnNames[j] = ctx.LoadNonEmptyString();
                }

                var expression = ctx.LoadNonEmptyString();
                var outputColumnName = ctx.LoadNonEmptyString();

                // -1 means there is no vector input; any other value must index one of the
                // inputSize inputs read above. Checking the upper bound here rejects a corrupt
                // index at decode time instead of passing it on to ParseAndBindLambda.
                var vectorInputColumn = ctx.Reader.ReadInt32();
                env.CheckDecode(-1 <= vectorInputColumn && vectorInputColumn < inputSize);

                var inputTypes = new DataViewType[inputSize];
                for (int j = 0; j < inputSize; j++)
                {
                    var dataKindIndex = ctx.Reader.ReadInt32();
                    var kind = InternalDataKindExtensions.FromIndex(dataKindIndex);
                    inputTypes[j] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
                }

                var node = ExpressionEstimator.ParseAndBindLambda(env, expression, vectorInputColumn, inputTypes, out var perm);
                columns[i] = new ColumnInfo(env, inputColumnNames, inputTypes, expression, outputColumnName, vectorInputColumn, node, perm);
            }

            return new ExpressionTransformer(env, columns);
        }
Пример #4
0
        /// <summary>
        /// Factory for SignatureLoadModel. Loads the fast-forest model parameters and, when a
        /// "Calibrator" sub-model is stored alongside them, wraps both in a
        /// <c>SchemaBindableCalibratedPredictor</c>.
        /// </summary>
        private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var forest = new FastForestClassificationModelParameters(env, ctx);
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out var storedCalibrator, @"Calibrator");

            if (storedCalibrator != null)
            {
                return new SchemaBindableCalibratedPredictor(env, forest, storedCalibrator);
            }

            // No calibrator was saved: the uncalibrated forest is the predictor.
            return forest;
        }
Пример #5
0
        /// <summary>
        /// Factory for SignatureLoadModel. Loads the LightGBM binary model parameters and, when a
        /// "Calibrator" sub-model is stored alongside them, wraps both in a
        /// <c>ValueMapperCalibratedModelParameters</c>.
        /// </summary>
        private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var booster = new LightGbmBinaryModelParameters(env, ctx);
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out var storedCalibrator, @"Calibrator");

            if (storedCalibrator != null)
            {
                return new ValueMapperCalibratedModelParameters<LightGbmBinaryModelParameters, ICalibrator>(env, booster, storedCalibrator);
            }

            // No calibrator was saved: the uncalibrated model is the predictor.
            return booster;
        }
            /// <summary>
            /// Deserializing constructor: reads the column count and the per-column slot counts
            /// from the main stream, then loads the base count table from the "BaseTable" sub-model.
            /// </summary>
            /// <param name="env">Host environment passed to the base class.</param>
            /// <param name="ctx">Model load context; checked for null.</param>
            public MultiCountTable(IHostEnvironment env, ModelLoadContext ctx)
                : base(env, LoaderSignature)
            {
                Host.CheckValue(ctx, nameof(ctx));

                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // int: ColCount
                // int[]: SlotCount
                // count table (in a separate folder)

                // NOTE(review): ColCount is not validated before being used as the array length;
                // confirm that ReadIntArray rejects negative sizes.
                ColCount  = ctx.Reader.ReadInt32();
                SlotCount = ctx.Reader.ReadIntArray(ColCount);
                ctx.LoadModel <CountTableBase, SignatureLoadModel>(Host, out _baseTable, "BaseTable");
            }
Пример #7
0
        /// <summary>
        /// Loads a <see cref="TagViewTransform"/> over <paramref name="input"/> from a model context.
        /// An input that already implements <c>ITaggedDataView</c> is rejected before any other
        /// argument is validated.
        /// </summary>
        public static TagViewTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));

            if (input is ITaggedDataView)
            {
                throw env.Except("The input view is already tagged. Don't tag again.");
            }

            var registered = env.Register(RegistrationName);
            registered.CheckValue(ctx, nameof(ctx));
            registered.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            return registered.Apply("Loading Model", ch => new TagViewTransform(registered, ctx, input));
        }
Пример #8
0
        // Factory method for SignatureLoadModel.
        // Reads the serialized TensorFlow graph and the input/output column names, then builds
        // the transform. A list of output names is stored only when ModelVerReadable is at least
        // 0x00010002; otherwise exactly one output name follows the inputs.
        private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // stream: tensorFlow model.
            // int: number of input columns
            // for each input column
            //   int: id of input column name
            // int: number of output columns (only when ModelVerReadable >= 0x00010002)
            // for each output column
            //   int: id of output column name
            byte[] modelBytes = null;
            if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
            {
                throw env.ExceptDecode();
            }

            var numInputs = ctx.Reader.ReadInt32();
            env.CheckDecode(numInputs > 0);
            string[] inputs = new string[numInputs];
            for (int j = 0; j < inputs.Length; j++)
            {
                inputs[j] = ctx.LoadNonEmptyString();
            }

            bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002;
            var numOutputs = 1;
            if (isMultiOutput)
            {
                numOutputs = ctx.Reader.ReadInt32();
            }

            env.CheckDecode(numOutputs > 0);
            var outputs = new string[numOutputs];
            for (int j = 0; j < outputs.Length; j++)
            {
                outputs[j] = ctx.LoadNonEmptyString();
            }

            // Create the session only after all column metadata has decoded successfully. A
            // corrupt header then fails before a session (which may own native resources) exists.
            var session = TensorFlowUtils.LoadTFSession(env, modelBytes);
            return new TensorFlowTransform(env, session, inputs, outputs);
        }
Пример #9
0
        // Factory for SignatureLoadModel.
        // Reads the contract name stored with the model, resolves the matching custom-mapping
        // factory through the component catalog, and asks that factory to rebuild the transformer.
        private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var contractName = ctx.LoadString();
            object resolved = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);

            // A null result and a non-factory result are reported the same way.
            var mappingFactory = resolved as ICustomMappingFactory;
            if (mappingFactory == null)
            {
                throw env.Except($"The class with contract '{contractName}' must derive from '{typeof(CustomMappingFactory<,>).FullName}'.");
            }

            return mappingFactory.CreateTransformer(env, contractName);
        }
Пример #10
0
        /// <summary>
        /// Deserializing constructor. Registers a host under <paramref name="registrationName"/> and
        /// checks that <paramref name="ctx"/> is at a model whose version info matches the value
        /// returned by the concrete type's public static <c>GetVersionInfo</c> method.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="registrationName">Name under which the component host is registered.</param>
        /// <param name="ctx">Model load context; must not be null.</param>
        public PredictorBase(IHostEnvironment env, string registrationName, ModelLoadContext ctx)
        {
            // Check env before Register dereferences it. ctx is checked (not just asserted) up
            // front because it is dereferenced below in every build configuration.
            Contracts.CheckValue(env, nameof(env));
            Host = env.Register(registrationName);
            Host.CheckValue(ctx, nameof(ctx));

            // GetVersionInfo is a static method on each concrete type, so it is found through
            // reflection on the runtime type rather than through a virtual call.
            var type = GetType();
            var getVersionInfo = type.GetMethod("GetVersionInfo", BindingFlags.Public | BindingFlags.Static);
            if (getVersionInfo == null)
            {
                throw Host.Except($"Type '{type}' does not have a public static method 'GetVersionInfo'.");
            }

            var versionInfo = (VersionInfo)getVersionInfo.Invoke(null, null);
            ctx.CheckAtModel(versionInfo);
        }
Пример #11
0
        /// <summary>
        /// Factory for SignatureLoadModel. Models written at or before <c>BeforeOrderVersion</c>
        /// contain an extra float that is read and discarded before the loading constructor runs.
        /// </summary>
        private static VectorToImageConvertingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            var componentHost = env.Register(RegistrationName);
            componentHost.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            bool isLegacyLayout = ctx.Header.ModelVerWritten <= VectorToImageConvertingTransformer.BeforeOrderVersion;
            if (isLegacyLayout)
            {
                // The value is unused; reading it only advances the stream.
                ctx.Reader.ReadFloat();
            }

            return componentHost.Apply("Loading Model", ch => new VectorToImageConvertingTransformer(componentHost, ctx));
        }
        /// <summary>
        /// Loads a <see cref="PcaTransform"/> over <paramref name="input"/>. The stream begins with
        /// sizeof(Float), which must match this build; the constructor reads everything after it.
        /// </summary>
        public static PcaTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var pcaHost = env.Register(RegistrationName);
            pcaHost.CheckValue(ctx, nameof(ctx));
            pcaHost.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // <remainder handled in ctors>
            int storedFloatSize = ctx.Reader.ReadInt32();
            pcaHost.CheckDecode(storedFloatSize == sizeof(Float));

            return pcaHost.Apply("Loading Model", ch => new PcaTransform(pcaHost, ctx, input));
        }
Пример #13
0
        // Factory for SignatureLoadModel.
        // Resolves the transformer through MEF: the contract name saved in the model selects which
        // exported ITransformer the composition container returns.
        private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var contractName = ctx.LoadString();
            var container = env.GetCompositionContainer()
                ?? throw Contracts.Except("Unable to get the MEF composition container");

            return container.GetExportedValue<ITransformer>(contractName);
        }
Пример #14
0
        /// <summary>
        /// Deserializing constructor: reads the float size and the list of source columns to filter.
        /// Each column must exist in <c>Source.Schema</c>, may appear only once, and must have a
        /// type accepted by <c>TestType</c>.
        /// </summary>
        /// <param name="host">Host passed to the base class.</param>
        /// <param name="ctx">Model load context; checked for null.</param>
        /// <param name="input">Source data view passed to the base class.</param>
        public NAFilter(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of columns
            // int[]: ids of column names
            int cbFloat = ctx.Reader.ReadInt32();

            // Both single- and double-precision float sizes are accepted.
            Host.CheckDecode(cbFloat == sizeof(Single) || cbFloat == sizeof(Double));
            int cinfo = ctx.Reader.ReadInt32();

            Host.CheckDecode(cinfo > 0);

            _infos = new ColInfo[cinfo];
            // Maps a source column index to its position in _infos; also used to detect duplicates.
            _srcIndexToInfoIndex = new Dictionary <int, int>(_infos.Length);
            var schema = Source.Schema;

            for (int i = 0; i < cinfo; i++)
            {
                string src = ctx.LoadNonEmptyString();
                int    index;
                if (!schema.TryGetColumnIndex(src, out index))
                {
                    throw Host.Except("Source column '{0}' not found", src);
                }
                if (_srcIndexToInfoIndex.ContainsKey(index))
                {
                    throw Host.Except("Source column '{0}' specified multiple times", src);
                }

                var type = schema.GetColumnType(index);
                if (!TestType(type))
                {
                    throw Host.Except("Column '{0}' does not have compatible numeric type", src);
                }

                _infos[i] = new ColInfo(index, type);
                _srcIndexToInfoIndex.Add(index, i);
            }
        }
Пример #15
0
        /// <summary>
        /// Loads a <see cref="MultiClassConvertTransform"/> over <paramref name="input"/>. The
        /// float-size header is read and validated inside the "Loading Model" action, immediately
        /// before the constructor consumes the rest of the stream.
        /// </summary>
        public static MultiClassConvertTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var registered = env.Register(RegistrationName);
            registered.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            registered.CheckValue(input, nameof(input));

            return registered.Apply("Loading Model", channel =>
            {
                // *** Binary format ***
                // int: sizeof(Float)
                // <remainder handled in ctors>
                int storedFloatSize = ctx.Reader.ReadInt32();
                channel.CheckDecode(storedFloatSize == sizeof(float));
                return new MultiClassConvertTransform(registered, ctx, input);
            });
        }
Пример #16
0
        /// <summary>
        /// Factory for SignatureLoadModel. Loads the linear binary predictor and, if a "Calibrator"
        /// sub-model exists, wraps it. Calibrators implementing <c>IParameterMixer</c> get a
        /// <c>ParameterMixingCalibratedPredictor</c>; all others get a
        /// <c>SchemaBindableCalibratedPredictor</c>.
        /// </summary>
        public static IPredictorProducing<Float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var linear = new LinearBinaryPredictor(env, ctx);
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out var storedCalibrator, @"Calibrator");

            if (storedCalibrator == null)
            {
                return linear;
            }

            return storedCalibrator is IParameterMixer
                ? (IPredictorProducing<Float>)new ParameterMixingCalibratedPredictor(env, linear, storedCalibrator)
                : new SchemaBindableCalibratedPredictor(env, linear, storedCalibrator);
        }
Пример #17
0
        /// <summary>Creates a <see cref="SkipTakeFilter"/> instance from a saved model context.</summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Model load context; must not be null.</param>
        /// <param name="input">Data view to filter; must not be null.</param>
        public static SkipTakeFilter Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(RegistrationName);

            h.CheckValue(ctx, nameof(ctx));
            // Validate input up front, like the other IDataView-taking loaders, instead of handing
            // a null view to the constructor.
            h.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // long: skip
            // long: take
            long skip = ctx.Reader.ReadInt64();
            h.CheckDecode(skip >= 0);

            long take = ctx.Reader.ReadInt64();
            h.CheckDecode(take >= 0);

            return h.Apply("Loading Model", ch => new SkipTakeFilter(skip, take, h, input));
        }
        // Factory method for SignatureLoadModel.
        // Deserializing constructor: restores the input column name, the column prefix and the
        // C++ transformer state that was saved with the model.
        internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) :
            base(host.Register(nameof(DateTimeTransformer)))
        {
            Host.CheckValue(ctx, nameof(ctx));
            host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // string: name of input column
            // string: column prefix
            // int: length of C++ state array
            // byte[]: C++ state array
            var inputColumnName = ctx.Reader.ReadString();
            var columnPrefix = ctx.Reader.ReadString();
            _column = new LongTypedColumn(inputColumnName, columnPrefix);

            var stateLength = ctx.Reader.ReadInt32();
            var savedState = ctx.Reader.ReadByteArray(stateLength);
            _column.CreateTransformerFromSavedData(savedState);
        }
Пример #19
0
        /// <summary>
        /// Deserializing constructor: loads the data loader saved as "Loader", then chains the
        /// saved transforms on top of it in index order.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ctx">Model load context; must not be null.</param>
        public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
        {
            // Same argument validation as the parallel Transformer deserializer.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: number of transforms (n)
            // loader: sub-model "Loader"
            // transforms: sub-models named by TransformDirTemplate for indices 0..n-1
            int n = ctx.Reader.ReadInt32();
            // A negative count can only come from a corrupt model; reject it rather than
            // silently loading no transforms.
            env.CheckDecode(n >= 0);

            // NOTE(review): the loader is created over MultiFileSource(null), i.e. without files;
            // presumably only its schema/pipeline is needed here — confirm.
            ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

            IDataView data = loader;
            for (int i = 0; i < n; i++)
            {
                var dirName = string.Format(TransformDirTemplate, i);
                ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
                data = xf;
            }

            _env = env;
            _xf = data;
        }
Пример #20
0
        /// <summary>
        /// Reads a vector of text values that was saved with a codec. The codec specification
        /// comes first, followed by a length-prefixed block of codec-encoded bytes.
        /// </summary>
        /// <param name="ch">Channel used for validation and error reporting.</param>
        /// <param name="ctx">Model load context; its reader must be positioned at the codec specification.</param>
        /// <param name="factory">Codec factory used to read the codec specification from the stream.</param>
        /// <param name="values">Receives the decoded vector of text values.</param>
        private static void Load(IChannel ch, ModelLoadContext ctx, CodecFactory factory, ref VBuffer <ReadOnlyMemory <char> > values)
        {
            Contracts.AssertValue(ch);
            ch.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // Codec parameterization: A codec parameterization that should be a ReadOnlyMemory codec
            // int: n, the number of bytes used to write the values
            // byte[n]: As encoded using the codec

            // Get the codec from the factory, and from the stream. We have to
            // attempt to read the codec from the stream, since codecs can potentially
            // be versioned based on their parameterization.
            IValueCodec codec;

            // This *could* happen if we have an old version attempt to read a new version.
            // Enabling this sort of binary compatibility is why we also need to write the
            // codec specification.
            if (!factory.TryReadCodec(ctx.Reader.BaseStream, out codec))
            {
                throw ch.ExceptDecode();
            }
            ch.AssertValue(codec);
            // Only a vector-of-text codec can be cast to the type used below.
            ch.CheckDecode(codec.Type.IsVector);
            ch.CheckDecode(codec.Type.ItemType.IsText);
            var textCodec = (IValueCodec <VBuffer <ReadOnlyMemory <char> > >)codec;

            var bufferLen = ctx.Reader.ReadInt32();

            ch.CheckDecode(bufferLen >= 0);
            // Presumably SubsetStream exposes exactly bufferLen bytes of the base stream — confirm.
            using (var stream = new SubsetStream(ctx.Reader.BaseStream, bufferLen))
            {
                // NOTE(review): the second argument (1) appears to be the number of items to read.
                using (var reader = textCodec.OpenReader(stream, 1))
                {
                    reader.MoveNext();
                    values = default(VBuffer <ReadOnlyMemory <char> >);
                    reader.Get(ref values);
                }
                // The codec must have consumed the whole block; any leftover byte is a decode error.
                ch.CheckDecode(stream.ReadByte() == -1);
            }
        }
Пример #21
0
        /// <summary>
        /// Loads the entire composite data loader (loader + transforms) from the context.
        /// If there are no transforms, the underlying loader is returned.
        /// </summary>
        public static IDataLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
        {
            Contracts.CheckValue(env, nameof(env));
            var compositeHost = env.Register(RegistrationName);
            compositeHost.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            compositeHost.CheckValue(files, nameof(files));

            using (var channel = compositeHost.Start("Components"))
            {
                // The base loader is loaded first; the saved transforms are layered on it afterwards.
                ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(compositeHost, out IDataLoader baseLoader, "Loader", files);

                // A nested composite loader is not expected as the base loader.
                compositeHost.Assert(!(baseLoader is CompositeDataLoader));
                return LoadTransforms(ctx, baseLoader, compositeHost, x => true);
            }
        }
Пример #22
0
        // Factory method for SignatureLoadModel.
        // Restores an OnnxTransform from the embedded ONNX model bytes and the saved column names.
        // Models written after version 0x00010001 store explicit input/output counts; older models
        // hold exactly one input name and one output name.
        private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            byte[] modelBytes = null;
            if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray()))
            {
                throw env.ExceptDecode();
            }

            bool hasExplicitCounts = ctx.Header.ModelVerWritten > 0x00010001;

            // Reads a (possibly implicit) positive count followed by that many non-empty names.
            string[] ReadColumnNames()
            {
                int count = hasExplicitCounts ? ctx.Reader.ReadInt32() : 1;
                env.CheckDecode(count > 0);
                var names = new string[count];
                for (int k = 0; k < names.Length; k++)
                {
                    names[k] = ctx.LoadNonEmptyString();
                }
                return names;
            }

            var inputNames = ReadColumnNames();
            var outputNames = ReadColumnNames();

            var args = new Arguments()
            {
                InputColumns = inputNames,
                OutputColumns = outputNames
            };
            return new OnnxTransform(env, args, modelBytes);
        }
Пример #23
0
        // Factory method for SignatureLoadModel.
        // Deserializing constructor: loads the calibrator sub-model and, for models written at
        // version 0x00010002 or later, the score column name. Older models fall back to
        // DefaultColumnNames.Score.
        private protected CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, string loaderSignature)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer <TICalibrator>)))
        {
            // A real check rather than AssertValue: asserts are typically debug-only, and ctx is
            // dereferenced immediately below.
            Contracts.CheckValue(ctx, nameof(ctx));

            _loaderSignature = loaderSignature;
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // model: _calibrator (sub-model "Calibrator")
            // string: _scoreColumnName (only when ModelVerWritten >= 0x00010002)
            ctx.LoadModel<TICalibrator, SignatureLoadModel>(env, out _calibrator, "Calibrator");
            _scoreColumnName = ctx.Header.ModelVerWritten >= 0x00010002
                ? ctx.LoadString()
                : DefaultColumnNames.Score;
        }
            /// <summary>
            /// Deserializing constructor: restores the wrapped feature-contribution predictor and the
            /// top/bottom contribution settings.
            /// </summary>
            public BindableMapper(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _env = env;
                _env.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // model: IFeatureContributionMapper predictor (sub-model ModelFileUtils.DirPredictor)
                // int: topContributionsCount
                // int: bottomContributionsCount
                // bool: normalize
                // bool: stringify
                ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out Predictor, ModelFileUtils.DirPredictor);
                GenericMapper = ScoreUtils.GetSchemaBindableMapper(_env, Predictor, null);

                // Each count must lie in the range (0, MaxTopBottom].
                int top = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 < top && top <= MaxTopBottom);
                _topContributionsCount = top;

                int bottom = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(0 < bottom && bottom <= MaxTopBottom);
                _bottomContributionsCount = bottom;

                _normalize = ctx.Reader.ReadBoolByte();
                Stringify = ctx.Reader.ReadBoolByte();
            }
Пример #25
0
        // Factory method for SignatureLoadModel.
        // Restores a single-input/single-output OnnxTransform.
        // *** Binary format ***
        // stream "OnnxModel": serialized ONNX model bytes
        // int: id of input column name
        // int: id of output column name
        private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            byte[] onnxBytes = null;
            bool streamFound = ctx.TryLoadBinaryStream("OnnxModel", reader => onnxBytes = reader.ReadByteArray());
            if (!streamFound)
            {
                throw env.ExceptDecode();
            }

            // Object-initializer assignments run in textual order: input name first, then output.
            var args = new Arguments()
            {
                InputColumn = ctx.LoadNonEmptyString(),
                OutputColumn = ctx.LoadNonEmptyString()
            };
            return new OnnxTransform(env, args, onnxBytes);
        }
Пример #26
0
        /// <summary>
        /// Deserializing constructor: restores the filtered column, its [min, max] range and the
        /// complement/include-bound flags.
        /// </summary>
        /// <param name="host">Host passed to the base class.</param>
        /// <param name="ctx">Model load context; checked for null.</param>
        /// <param name="input">Source data view passed to the base class.</param>
        private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // int: id of column name
            // double: min
            // double: max
            // byte: complement
            // byte: includeMin
            // byte: includeMax
            int cbFloat = ctx.Reader.ReadInt32();
            Host.CheckDecode(cbFloat == sizeof(Float));

            var column = ctx.LoadNonEmptyString();
            var schema = Source.Schema;

            // Except(format, args...) treats its FIRST argument as the message format, so a leading
            // "column"/"min" argument would have replaced the whole message. Pass the real format
            // string first.
            if (!schema.TryGetColumnIndex(column, out _index))
            {
                throw Host.Except("Source column '{0}' not found", column);
            }

            _type = schema.GetColumnType(_index);
            if (_type != NumberType.R4 && _type != NumberType.R8 && _type.KeyCount == 0)
            {
                throw Host.Except("Column '{0}' does not have compatible type", column);
            }

            _min = ctx.Reader.ReadDouble();
            _max = ctx.Reader.ReadDouble();
            // Written as !(min <= max) so that NaN bounds are rejected as well.
            if (!(_min <= _max))
            {
                throw Host.Except("min must be less than or equal to max");
            }
            _complement = ctx.Reader.ReadBoolByte();
            _includeMin = ctx.Reader.ReadBoolByte();
            _includeMax = ctx.Reader.ReadBoolByte();
        }
Пример #27
0
        /// <summary>
        /// Deserializing constructor: restores the filtered column, its [min, max] range and the
        /// complement/include-bound flags.
        /// </summary>
        /// <param name="host">Host passed to the base class.</param>
        /// <param name="ctx">Model load context; checked for null.</param>
        /// <param name="input">Source data view passed to the base class.</param>
        private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(float)
            // int: id of column name
            // double: min
            // double: max
            // byte: complement
            // byte: includeMin
            // byte: includeMax
            int cbFloat = ctx.Reader.ReadInt32();
            Host.CheckDecode(cbFloat == sizeof(float));

            var column = ctx.LoadNonEmptyString();
            var schema = Source.Schema;

            if (!schema.TryGetColumnIndex(column, out _index))
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "source", column);
            }

            _type = schema[_index].Type;
            if (_type != NumberDataViewType.Single && _type != NumberDataViewType.Double && _type.GetKeyCount() == 0)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "source", column, "Single, Double or Key", _type.ToString());
            }

            _min = ctx.Reader.ReadDouble();
            _max = ctx.Reader.ReadDouble();
            // Written as !(min <= max) so that NaN bounds are rejected as well.
            if (!(_min <= _max))
            {
                // Except(msg, args...) uses its first argument as the message. Passing "min" first
                // would have produced the message "min", so the real text goes first.
                throw Host.Except("min must be less than or equal to max");
            }
            _complement = ctx.Reader.ReadBoolByte();
            _includeMin = ctx.Reader.ReadBoolByte();
            _includeMax = ctx.Reader.ReadBoolByte();
        }
Пример #28
0
            /// <summary>
            /// Deserializing constructor: loads the data loader saved as "Loader", then chains the
            /// saved transforms on top of it in index order.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="ctx">Model load context; must not be null.</param>
            public Transformer(IHostEnvironment env, ModelLoadContext ctx)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Transformer));
                _host.CheckValue(ctx, nameof(ctx));

                ctx.CheckAtModel(GetVersionInfo());

                // *** Binary format ***
                // int: number of transforms (n)
                // loader: sub-model "Loader"
                // transforms: sub-models named by TransformDirTemplate for indices 0..n-1
                int n = ctx.Reader.ReadInt32();
                // A negative count can only come from a corrupt model; reject it rather than
                // silently loading no transforms.
                _host.CheckDecode(n >= 0);

                ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

                IDataView data = loader;
                for (int i = 0; i < n; i++)
                {
                    var dirName = string.Format(TransformDirTemplate, i);
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
                    data = xf;
                }

                _xf = data;
            }
        /// <summary>
        /// Loads a <see cref="ParquetLoader"/> for <paramref name="files"/> from a model context.
        /// The saved chunk size and big-integer-as-date flag are restored into <c>Arguments</c>.
        /// </summary>
        public static ParquetLoader Create(IHostEnvironment env, ModelLoadContext ctx, IMultiStreamSource files)
        {
            Contracts.CheckValue(env, nameof(env));
            IHost host = env.Register(LoaderName);
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());
            env.CheckValue(files, nameof(files));

            // *** Binary format ***
            // int: cached chunk size
            // bool: TreatBigIntegersAsDates flag
            int chunkReadSize = ctx.Reader.ReadInt32();
            bool bigIntegersAsDates = ctx.Reader.ReadBoolean();

            var args = new Arguments
            {
                ColumnChunkReadSize = chunkReadSize,
                TreatBigIntegersAsDates = bigIntegersAsDates
            };

            // The source stream is opened inside the loading action.
            return host.Apply("Loading Model", ch => new ParquetLoader(args, host, OpenStream(files)));
        }
        /// <summary>
        /// Creates an instance of the transform from a context.
        /// </summary>
        /// <remarks>
        /// The model stores a serialized reference to a static load method plus an opaque payload.
        /// The method is reconstructed and invoked on a reader over that payload.
        /// NOTE(review): the method to invoke comes from bytes inside the model file. Depending on
        /// what DeserializeStaticDelegateOrNull permits, an untrusted model may be able to choose
        /// which static method runs. Confirm, and treat such model files as code.
        /// </remarks>
        public static ITransformTemplate Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(LoaderSignature);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: Number of bytes the load method was serialized to
            // byte[n]: The serialized load method info
            // <arbitrary>: Arbitrary bytes saved by the save action
            var serializedLoadMethod = ctx.Reader.ReadByteArray();
            host.CheckDecode(Utils.Size(serializedLoadMethod) > 0);

            var loadFunc = DeserializeStaticDelegateOrNull(host, serializedLoadMethod, out Exception error);
            if (loadFunc == null)
            {
                // A null delegate is accompanied by the reason it could not be rebuilt.
                host.AssertValue(error);
                throw error;
            }

            var payload = ctx.Reader.ReadByteArray() ?? new byte[0];
            using (var payloadStream = new MemoryStream(payload))
            using (var payloadReader = new BinaryReader(payloadStream))
            {
                var result = loadFunc(payloadReader, env, input);
                env.Check(result != null, "Load method returned null");
                return result;
            }
        }