예제 #1
0
            /// <summary>
            /// Reconstructs a <see cref="TensorFlowMapper"/> from a serialized model context.
            /// </summary>
            /// <remarks>
            /// Reads, in order: the number of input columns (must be positive), one non-empty
            /// string per input column, the "TFModel" binary stream and the non-empty output
            /// column name. A decode exception is thrown when the count is not positive or the
            /// model stream is missing.
            /// </remarks>
            /// <param name="env">Host environment used for validation and for creating exceptions.</param>
            /// <param name="ctx">Load context; its version is checked against <c>GetVersionInfo()</c>.</param>
            /// <param name="schema">Schema passed through unchanged to the mapper constructor.</param>
            /// <returns>A mapper built from the decoded model bytes and column names.</returns>
            public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
            {
                Contracts.CheckValue(env, nameof(env));
                env.CheckValue(ctx, nameof(ctx));
                ctx.CheckAtModel(GetVersionInfo());

                var numInputs = ctx.Reader.ReadInt32();

                // Use the environment-scoped check so this failure is reported the same way as
                // the other decode failures raised through env in this method.
                env.CheckDecode(numInputs > 0);

                string[] source = new string[numInputs];
                for (int j = 0; j < source.Length; j++)
                {
                    source[j] = ctx.LoadNonEmptyString();
                }

                byte[] data = null;
                if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray()))
                {
                    throw env.ExceptDecode();
                }

                var outputColName = ctx.LoadNonEmptyString();

                return new TensorFlowMapper(env, schema, data, source, outputColName);
            }
예제 #2
0
        // Factory method for SignatureLoadModel.
        // Decodes the "OnnxModel" stream followed by the input and output column names,
        // then builds the transform from them.
        private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            byte[] onnxBytes = null;
            bool hasModel = ctx.TryLoadBinaryStream("OnnxModel", reader => onnxBytes = reader.ReadByteArray());
            if (!hasModel)
            {
                throw env.ExceptDecode();
            }

            // Only models written after version 0x00010001 store explicit column counts;
            // otherwise exactly one input name and one output name follow.
            bool hasExplicitCounts = ctx.Header.ModelVerWritten > 0x00010001;

            int inputCount = 1;
            if (hasExplicitCounts)
            {
                inputCount = ctx.Reader.ReadInt32();
            }
            env.CheckDecode(inputCount > 0);

            var inputNames = new string[inputCount];
            for (int i = 0; i < inputNames.Length; i++)
            {
                inputNames[i] = ctx.LoadNonEmptyString();
            }

            int outputCount = 1;
            if (hasExplicitCounts)
            {
                outputCount = ctx.Reader.ReadInt32();
            }
            env.CheckDecode(outputCount > 0);

            var outputNames = new string[outputCount];
            for (int i = 0; i < outputNames.Length; i++)
            {
                outputNames[i] = ctx.LoadNonEmptyString();
            }

            var args = new Arguments() { InputColumns = inputNames, OutputColumns = outputNames };
            return new OnnxTransform(env, args, onnxBytes);
        }
예제 #3
0
        // Factory method for SignatureLoadModel.
        //
        // Decodes, in order: the "OnnxModel" binary stream, the input column names, the output
        // column names and the custom shape overrides, then builds the transformer from them.
        // A decode exception is thrown when the model stream is missing or any count is invalid.
        private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            byte[] modelBytes = null;
            if (!ctx.TryLoadBinaryStream("OnnxModel", r => modelBytes = r.ReadByteArray()))
            {
                throw env.ExceptDecode();
            }

            // Only models written after version 0x00010001 store explicit column counts;
            // otherwise exactly one input name and one output name follow.
            bool supportsMultiInputOutput = ctx.Header.ModelVerWritten > 0x00010001;

            var numInputs = supportsMultiInputOutput ? ctx.Reader.ReadInt32() : 1;

            env.CheckDecode(numInputs > 0);
            var inputs = new string[numInputs];

            for (int j = 0; j < inputs.Length; j++)
            {
                inputs[j] = ctx.LoadNonEmptyString();
            }

            var numOutputs = supportsMultiInputOutput ? ctx.Reader.ReadInt32() : 1;

            env.CheckDecode(numOutputs > 0);
            var outputs = new string[numOutputs];

            for (int j = 0; j < outputs.Length; j++)
            {
                outputs[j] = ctx.LoadNonEmptyString();
            }

            // Load custom-provided shapes. Those shapes overwrite shapes loaded from the ONNX model file.
            // 0 means no custom shape. Non-zero means count of custom shapes.
            // NOTE(review): the count is read regardless of ModelVerWritten - confirm that every
            // model version accepted by CheckAtModel actually writes it.
            int customShapeInfosLength = ctx.Reader.ReadInt32();

            // A negative count can only come from a corrupt stream; reject it instead of
            // silently treating it as "no custom shapes".
            env.CheckDecode(customShapeInfosLength >= 0);

            CustomShapeInfo[] loadedCustomShapeInfos = null;
            if (customShapeInfosLength > 0)
            {
                loadedCustomShapeInfos = new CustomShapeInfo[customShapeInfosLength];
                for (int i = 0; i < customShapeInfosLength; ++i)
                {
                    var name = ctx.LoadNonEmptyString();
                    var shape = ctx.Reader.ReadIntArray();
                    loadedCustomShapeInfos[i] = new CustomShapeInfo()
                    {
                        Name = name, Shape = shape
                    };
                }
            }

            var options = new Options()
            {
                InputColumns = inputs, OutputColumns = outputs, CustomShapeInfos = loadedCustomShapeInfos
            };

            return new OnnxTransformer(env, options, modelBytes);
        }
예제 #4
0
            // Verifies that the label column of the given predictors has exactly two classes.
            // When the count differs, throws a parameter exception if 'user' is true and a
            // decode exception otherwise.
            private static void CheckBinaryLabel(bool user, IHostEnvironment env, IPredictorModel[] predictors)
            {
                int numClasses = CheckLabelColumn(env, predictors, true);
                if (numClasses == 2)
                {
                    return;
                }

                string message = string.Format("Expected label to have exactly 2 classes, instead has {0}", numClasses);
                if (user)
                {
                    throw env.ExceptParam(nameof(predictors), message);
                }

                throw env.ExceptDecode(message);
            }
예제 #5
0
 /// <summary>
 /// Loads a model from the supplied repository entry, using the default loader(s) named in
 /// the header. Throws a decode exception that mentions <paramref name="dir"/> when the
 /// model cannot be loaded.
 /// </summary>
 public static void LoadModel <TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, Repository.Entry ent, string dir, params object[] extra)
     where TRes : class
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(rep, nameof(rep));

     bool loaded = TryLoadModel <TRes, TSig>(env, out result, rep, ent, dir, extra);
     if (loaded)
     {
         return;
     }

     throw env.ExceptDecode("Couldn't load model: '{0}'", dir);
 }
예제 #6
0
 /// <summary>
 /// Loads an object from the given repository directory. Throws a decode exception when
 /// nothing could be loaded; on success the result is asserted to be non-null.
 /// </summary>
 public static void LoadModel <TRes, TSig>(IHostEnvironment env, out TRes result, RepositoryReader rep, string dir, params object[] extra)
     where TRes : class
 {
     Contracts.CheckValue(env, nameof(env));
     env.CheckValue(rep, nameof(rep));

     bool found = LoadModelOrNull <TRes, TSig>(env, out result, rep, dir, extra);
     if (found)
     {
         env.AssertValue(result);
         return;
     }

     throw env.ExceptDecode("Corrupt model file");
 }
            /// <summary>
            /// Reconstructs the <see cref="Bindings"/> of an optional-column transform from a model context.
            /// </summary>
            /// <remarks>
            /// Throws a decode exception when the "Schema.idv" stream is missing, when the column
            /// count is not positive, or when a stored column name is absent from the schema
            /// loaded out of that stream.
            /// </remarks>
            /// <param name="env">Host environment used for decode checks and to build the loader and saver.</param>
            /// <param name="ctx">Load context the bindings are read from.</param>
            /// <param name="input">Input schema; columns missing from it are recorded with index -1.</param>
            /// <param name="parent">Owning transform, passed through to the bindings constructor.</param>
            /// <returns>The decoded bindings.</returns>
            public static Bindings Create(IHostEnvironment env, ModelLoadContext ctx, Schema input, OptionalColumnTransform parent)
            {
                Contracts.AssertValue(ctx);
                Contracts.AssertValue(input);

                // *** Binary format ***
                // stream "Schema.idv": schema of the data view containing the optional columns
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   ColumnType: the type of the column

                byte[] buffer = null;
                if (!ctx.TryLoadBinaryStream("Schema.idv", r => buffer = r.ReadByteArray()))
                {
                    throw env.ExceptDecode();
                }

                // NOTE(review): the stream is not disposed here; presumably the loader keeps
                // using it for as long as loader.Schema is referenced - confirm.
                var strm = new MemoryStream(buffer, writable: false);
                var loader = new BinaryLoader(env, new BinaryLoader.Arguments(), strm);

                int size = ctx.Reader.ReadInt32();

                // Environment-scoped check, matching the env.CheckDecode call used in the loop below.
                env.CheckDecode(size > 0);

                var saver = new BinarySaver(env, new BinarySaver.Arguments());
                var names = new string[size];
                var columnTypes = new ColumnType[size];
                var srcCols = new int[size];
                var srcColsWithOptionalColumn = new int[size];

                for (int i = 0; i < size; i++)
                {
                    names[i] = ctx.LoadNonEmptyString();
                    columnTypes[i] = saver.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream);

                    // The column may legitimately be missing from the input schema; -1 marks that case.
                    int col;
                    bool success = input.TryGetColumnIndex(names[i], out col);
                    srcCols[i] = success ? col : -1;

                    // The serialized schema, however, must contain every stored column name.
                    success = loader.Schema.TryGetColumnIndex(names[i], out var colWithOptionalColumn);
                    env.CheckDecode(success);
                    srcColsWithOptionalColumn[i] = colWithOptionalColumn;
                }

                return new Bindings(parent, columnTypes, srcCols, srcColsWithOptionalColumn, input, loader.Schema, false, names);
            }
예제 #8
0
        // Factory method for SignatureLoadModel.
        // Loads the TensorFlow session from the "TFModel" stream, then decodes the input and
        // output column names and builds the transform.
        private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // stream: tensorFlow model.
            // int: number of input columns
            // for each input column
            //   int: id of int column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name

            byte[] tfModel = null;
            bool hasModel = ctx.TryLoadBinaryStream("TFModel", reader => tfModel = reader.ReadByteArray());
            if (!hasModel)
            {
                throw env.ExceptDecode();
            }

            var tfSession = TensorFlowUtils.LoadTFSession(env, tfModel);

            int inputCount = ctx.Reader.ReadInt32();
            env.CheckDecode(inputCount > 0);

            var inputNames = new string[inputCount];
            for (int i = 0; i < inputNames.Length; i++)
            {
                inputNames[i] = ctx.LoadNonEmptyString();
            }

            // The output count is stored only when the readable version is at least 0x00010002;
            // otherwise a single output name follows.
            int outputCount = ctx.Header.ModelVerReadable >= 0x00010002 ? ctx.Reader.ReadInt32() : 1;
            env.CheckDecode(outputCount > 0);

            var outputNames = new string[outputCount];
            for (int i = 0; i < outputNames.Length; i++)
            {
                outputNames[i] = ctx.LoadNonEmptyString();
            }

            return new TensorFlowTransform(env, tfSession, inputNames, outputNames);
        }
예제 #9
0
        // Factory method for SignatureLoadModel.
        // Decodes the "OnnxModel" stream plus one input and one output column name, then
        // builds the transform from them.
        private static OnnxTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            byte[] onnxBytes = null;
            bool hasModel = ctx.TryLoadBinaryStream("OnnxModel", reader => onnxBytes = reader.ReadByteArray());
            if (!hasModel)
            {
                throw env.ExceptDecode();
            }

            // The input name is stored before the output name.
            var inputName = ctx.LoadNonEmptyString();
            var outputName = ctx.LoadNonEmptyString();

            var args = new Arguments() { InputColumn = inputName, OutputColumn = outputName };
            return new OnnxTransform(env, args, onnxBytes);
        }
        // Factory method for SignatureLoadModel.
        //
        // For a frozen model the session is created from the "TFModel" byte stream. Otherwise
        // the files stored in the "TFSavedModel" stream are extracted into a fresh temporary
        // directory and the session is created from that directory. If anything fails after the
        // directory has been created, the directory is deleted and the exception is rethrown.
        private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // byte: indicator for frozen models
            // stream: tensorFlow model.
            // int: number of input columns
            // for each input column
            //   int: id of input column name
            // int: number of output columns
            // for each output column
            //   int: id of output column name
            GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen);
            if (isFrozen)
            {
                byte[] modelBytes = null;
                if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
                {
                    throw env.ExceptDecode();
                }
                return new TensorFlowTransform(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false);
            }

            var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid()));

            TensorFlowUtils.CreateFolderWithAclIfNotExists(env, tempDirPath);
            try
            {
                // Every extracted file must resolve to a path below the temporary directory.
                string tempDirPrefix = tempDirPath + Path.DirectorySeparatorChar;

                bool loaded = ctx.TryLoadBinaryStream("TFSavedModel", br =>
                {
                    int count = br.ReadInt32();
                    for (int n = 0; n < count; n++)
                    {
                        string relativeFile = br.ReadString();
                        long fileLength = br.ReadInt64();

                        // The relative path comes from the model file, so normalize it and reject
                        // anything (rooted paths, ".." segments) that would escape the temp directory.
                        string fullFilePath = Path.GetFullPath(Path.Combine(tempDirPath, relativeFile));
                        env.CheckDecode(fullFilePath.StartsWith(tempDirPrefix, StringComparison.Ordinal));

                        string fullFileDir = Path.GetDirectoryName(fullFilePath);
                        if (fullFileDir != tempDirPath)
                        {
                            TensorFlowUtils.CreateFolderWithAclIfNotExists(env, fullFileDir);
                        }
                        using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write))
                        {
                            // A short copy means the stream is truncated. Use a decode check rather
                            // than an assert so the failure is also detected in release builds.
                            long actualRead = br.BaseStream.CopyRange(fs, fileLength);
                            env.CheckDecode(actualRead == fileLength);
                        }
                    }
                });

                // A non-frozen model without its saved-model stream cannot be loaded; fail the
                // same way the frozen branch does when its stream is missing.
                if (!loaded)
                {
                    throw env.ExceptDecode();
                }

                return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true);
            }
            catch (Exception)
            {
                TensorFlowUtils.DeleteFolderWithRetries(env, tempDirPath);
                throw;
            }
        }