Ejemplo n.º 1
0
            /// <summary>
            /// Resolves, for every requested input, its column index in <paramref name="inputSchema"/> and the
            /// TensorFlow type and shape of the graph node with the same name, and checks that the column's
            /// shape agrees with the shape the graph node declares.
            /// </summary>
            /// <param name="graph">Graph that is indexed by each entry of <paramref name="source"/>.</param>
            /// <param name="source">Names used both as schema column names and as graph node names.</param>
            /// <param name="inputSchema">Schema in which the columns are looked up.</param>
            /// <returns>
            /// Parallel arrays indexed like <paramref name="source"/>, in this order: column names, schema
            /// column indices, whether each column is a vector, TensorFlow shapes (with every -1 dimension
            /// replaced by <c>BatchSize</c>), and TensorFlow data types.
            /// </returns>
            /// <remarks>
            /// Throws (through <c>Contracts.Except</c>) when a column is missing from the schema, when the
            /// graph has no node with that name, when the node's type is unsupported, or when the shapes
            /// do not match.
            /// </remarks>
            private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema)
            {
                var tfShapes        = new TFShape[source.Length];
                var tfTypes         = new TFDataType[source.Length];
                var colNames        = new string[source.Length];
                var inputColIndices = new int[source.Length];
                var isInputVector   = new bool[source.Length];

                for (int i = 0; i < source.Length; i++)
                {
                    colNames[i] = source[i];
                    if (!inputSchema.TryGetColumnIndex(colNames[i], out inputColIndices[i]))
                    {
                        throw Contracts.Except($"Column '{colNames[i]}' does not exist");
                    }

                    // Fail with a descriptive message instead of handing a null operation to TFOutput.
                    var tfOperation = graph[colNames[i]];
                    if (tfOperation == null)
                    {
                        throw Contracts.Except($"Input column '{colNames[i]}' does not exist in the model");
                    }

                    var tfoutput = new TFOutput(tfOperation);
                    if (!TensorFlowUtils.IsTypeSupported(tfoutput.OutputType))
                    {
                        throw Contracts.Except($"Input type '{tfoutput.OutputType}' of input column '{colNames[i]}' is not supported in TensorFlow");
                    }

                    tfShapes[i] = graph.GetTensorShape(tfoutput);
                    var type = inputSchema.GetColumnType(inputColIndices[i]);

                    // A leading -1 is treated as the batch dimension and excluded from the comparison.
                    // Exactly one dimension is dropped: the number of dimensions to skip is unrelated
                    // to the value of BatchSize, which is a dimension *size*.
                    bool hasBatchDim = tfShapes[i].NumDimensions > 0 && tfShapes[i][0] == -1;
                    int[] shape = tfShapes[i].ToIntArray().Skip(hasBatchDim ? 1 : 0).ToArray();

                    if (!type.IsVector || type.AsVector.DimCount == 1)
                    {
                        // Flat data: only the total element count has to agree. The seed of 1 keeps
                        // the product defined when no dimensions remain after dropping the batch one.
                        // NOTE(review): assumes ValueCount is meaningful for non-vector columns — confirm.
                        int valCount = shape.Aggregate(1, (x, y) => x * y);
                        if (type.ValueCount != valCount)
                        {
                            throw Contracts.Except($"Input shape mismatch: Input '{colNames[i]}' has shape {tfShapes[i].ToString()}, but input data is of length {type.ValueCount}.");
                        }
                    }
                    else if (shape.Length != type.AsVector.DimCount
                        || shape.Where((dim, j) => dim != type.AsVector.GetDim(j)).Any())
                    {
                        // The rank is compared first so GetDim is never asked for a dimension the
                        // column type does not have, and a tensor of lower rank cannot slip through.
                        throw Contracts.Except($"Input shape mismatch: Input '{colNames[i]}' has shape {tfShapes[i].ToString()}, but input data is {type.AsVector.ToString()}.");
                    }

                    isInputVector[i] = type.IsVector;

                    tfTypes[i] = tfoutput.OutputType;

                    // Replace every unknown (-1) dimension with BatchSize to obtain a concrete shape.
                    var l = new long[tfShapes[i].NumDimensions];
                    for (int ishape = 0; ishape < tfShapes[i].NumDimensions; ishape++)
                    {
                        l[ishape] = tfShapes[i][ishape] == -1 ? BatchSize : tfShapes[i][ishape];
                    }
                    tfShapes[i] = new TFShape(l);
                }
                return (colNames, inputColIndices, isInputVector, tfShapes, tfTypes);
            }
        /// <summary>
        /// Loads a TensorFlow session from serialized model bytes, validates the requested input and
        /// output node names against the session's graph, and caches the TensorFlow types and shapes of
        /// the inputs and the column types of the outputs.
        /// </summary>
        /// <param name="env">Host environment used to register this component.</param>
        /// <param name="modelBytes">Serialized model passed to <c>LoadTFSession</c>.</param>
        /// <param name="inputs">Names of the graph nodes to feed; each must exist and have a supported type.</param>
        /// <param name="outputs">Names of the graph nodes to fetch; each must exist and be listed only once.</param>
        private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
        {
            Contracts.CheckValue(env, nameof(env));
            // NOTE(review): nameof(RegistrationName) evaluates to the literal "RegistrationName", not to
            // the member's value — confirm whether env.Register(RegistrationName) was intended.
            _host = env.Register(nameof(RegistrationName));
            _host.CheckValue(modelBytes, nameof(modelBytes));
            // Validate the name arrays before they are enumerated, and before the session is loaded,
            // so a null argument is reported by parameter name rather than as a NullReferenceException.
            _host.CheckValue(inputs, nameof(inputs));
            _host.CheckValue(outputs, nameof(outputs));
            Session = LoadTFSession(modelBytes);
            foreach (var input in inputs)
            {
                _host.CheckNonWhiteSpace(input, nameof(inputs));
                if (Session.Graph[input] == null)
                {
                    throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
                }
                var tfInput = new TFOutput(Session.Graph[input]);
                if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
                {
                    throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
                }
            }

            // Tracks output names already seen so duplicates can be rejected.
            var newNames = new HashSet<string>();

            foreach (var output in outputs)
            {
                // Same validation as for the input names above.
                _host.CheckNonWhiteSpace(output, nameof(outputs));
                if (!newNames.Add(output))
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' specified multiple times");
                }
                if (Session.Graph[output] == null)
                {
                    throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' does not exist in the model");
                }
            }

            Inputs        = inputs;
            TFInputTypes  = new TFDataType[Inputs.Length];
            TFInputShapes = new TFShape[Inputs.Length];
            for (int i = 0; i < Inputs.Length; i++)
            {
                var tfInput = new TFOutput(Graph[Inputs[i]]);
                TFInputTypes[i]  = tfInput.OutputType;
                TFInputShapes[i] = Graph.GetTensorShape(tfInput);
                // Replace every unknown (-1) dimension with BatchSize to obtain a concrete shape.
                var newShape = new long[TFInputShapes[i].NumDimensions];
                for (int j = 0; j < TFInputShapes[i].NumDimensions; j++)
                {
                    newShape[j] = TFInputShapes[i][j] == -1 ? BatchSize : TFInputShapes[i][j];
                }
                TFInputShapes[i] = new TFShape(newShape);
            }

            Outputs       = outputs;
            OutputTypes   = new ColumnType[Outputs.Length];
            TFOutputTypes = new TFDataType[Outputs.Length];
            for (int i = 0; i < Outputs.Length; i++)
            {
                var tfOutput = new TFOutput(Graph[Outputs[i]]);
                var shape    = Graph.GetTensorShape(tfOutput);
                // A leading -1 is treated as the batch dimension and left out of the column type.
                // Exactly one dimension is dropped; BatchSize is a dimension size, not a skip count.
                bool hasBatchDim = shape.NumDimensions > 0 && shape[0] == -1;
                int[] dims = shape.ToIntArray().Skip(hasBatchDim ? 1 : 0).ToArray();
                var type   = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
                OutputTypes[i]   = new VectorType(type, dims);
                TFOutputTypes[i] = tfOutput.OutputType;
            }
        }