/// <summary>
/// Resolves and validates the metadata of every graph input named in <paramref name="source"/>.
/// </summary>
/// <param name="graph">Graph whose operations are looked up by input name.</param>
/// <param name="source">Input names; each must be a column of <paramref name="inputSchema"/> and an operation of <paramref name="graph"/>.</param>
/// <param name="inputSchema">Schema used to resolve the column index and column type of each input.</param>
/// <returns>
/// Parallel arrays indexed like <paramref name="source"/>: column names, schema column indices,
/// whether each column is a vector, tensor shapes (every -1 dimension replaced by <c>BatchSize</c>)
/// and tensor data types.
/// </returns>
private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema)
{
    int count = source.Length;
    var tfShapes = new TFShape[count];
    var tfTypes = new TFDataType[count];
    var colNames = new string[count];
    var inputColIndices = new int[count];
    var isInputVector = new bool[count];

    for (int i = 0; i < count; i++)
    {
        var name = source[i];
        colNames[i] = name;
        if (!inputSchema.TryGetColumnIndex(name, out inputColIndices[i]))
        {
            throw Contracts.Except($"Column '{name}' does not exist");
        }

        // Fail with a descriptive error instead of handing a null operation to TFOutput.
        var operation = graph[name];
        if (operation == null)
        {
            throw Contracts.Except($"Input column '{name}' does not exist in the model");
        }

        var tfoutput = new TFOutput(operation);
        if (!TensorFlowUtils.IsTypeSupported(tfoutput.OutputType))
        {
            throw Contracts.Except($"Input type '{tfoutput.OutputType}' of input column '{name}' is not supported in TensorFlow");
        }

        var graphShape = graph.GetTensorShape(tfoutput);
        var type = inputSchema.GetColumnType(inputColIndices[i]);

        // A leading -1 marks the batch dimension. Exactly one dimension is dropped for it;
        // the number of dimensions to skip is unrelated to the value of BatchSize.
        bool hasBatchDim = graphShape.NumDimensions > 0 && graphShape[0] == -1;
        int[] dims = graphShape.ToIntArray().Skip(hasBatchDim ? 1 : 0).ToArray();

        // NOTE(review): AsVector is presumably null for non-vector columns; those are compared
        // by total value count, like flat vectors — confirm against ColumnType.
        var vectorType = type.AsVector;
        if (vectorType == null || vectorType.DimCount == 1)
        {
            // Seeded with 1 so that a shape with no non-batch dimensions yields 1 instead of throwing.
            int valCount = dims.Aggregate(1, (x, y) => x * y);
            if (type.ValueCount != valCount)
            {
                throw Contracts.Except($"Input shape mismatch: Input '{name}' has shape {graphShape.ToString()}, but input data is of length {type.ValueCount}.");
            }
        }
        else if (dims.Select((dim, j) => dim != vectorType.GetDim(j)).Any(b => b))
        {
            throw Contracts.Except($"Input shape mismatch: Input '{name}' has shape {graphShape.ToString()}, but input data is {vectorType.ToString()}.");
        }

        isInputVector[i] = type.IsVector;
        tfTypes[i] = tfoutput.OutputType;

        // Store the shape with every unknown (-1) dimension replaced by BatchSize.
        var resolved = new long[graphShape.NumDimensions];
        for (int d = 0; d < resolved.Length; d++)
        {
            resolved[d] = graphShape[d] == -1 ? BatchSize : graphShape[d];
        }
        tfShapes[i] = new TFShape(resolved);
    }

    return (colNames, inputColIndices, isInputVector, tfShapes, tfTypes);
}
/// <summary>
/// Loads a session from <paramref name="modelBytes"/>, validates the requested input and output
/// names against its graph, and records the tensor types and shapes of the inputs and the
/// column types of the outputs.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="modelBytes">Model bytes passed to <c>LoadTFSession</c>; must not be null.</param>
/// <param name="inputs">Names of graph operations used as inputs; each must exist in the graph and have a supported type.</param>
/// <param name="outputs">Names of graph operations used as outputs; each must exist in the graph and be listed only once.</param>
private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
{
    Contracts.CheckValue(env, nameof(env));
    // NOTE(review): nameof(RegistrationName) evaluates to the literal text "RegistrationName",
    // not to the value of that member — confirm this is the intended registration name.
    _host = env.Register(nameof(RegistrationName));
    _host.CheckValue(modelBytes, nameof(modelBytes));
    // Reject null name arrays up front, before the session is loaded, rather than
    // failing later with a NullReferenceException while enumerating them.
    _host.CheckValue(inputs, nameof(inputs));
    _host.CheckValue(outputs, nameof(outputs));

    Session = LoadTFSession(modelBytes);

    foreach (var input in inputs)
    {
        _host.CheckNonWhiteSpace(input, nameof(inputs));
        if (Session.Graph[input] == null)
        {
            throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
        }

        var tfInput = new TFOutput(Session.Graph[input]);
        if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
        {
            throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
        }
    }

    var newNames = new HashSet<string>();
    foreach (var output in outputs)
    {
        // Same whitespace validation as applied to the input names.
        _host.CheckNonWhiteSpace(output, nameof(outputs));
        if (!newNames.Add(output))
        {
            throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' specified multiple times");
        }
        if (Session.Graph[output] == null)
        {
            throw _host.ExceptParam(nameof(outputs), $"Output column '{output}' does not exist in the model");
        }
    }

    Inputs = inputs;
    TFInputTypes = new TFDataType[Inputs.Length];
    TFInputShapes = new TFShape[Inputs.Length];
    for (int i = 0; i < Inputs.Length; i++)
    {
        var tfInput = new TFOutput(Graph[Inputs[i]]);
        TFInputTypes[i] = tfInput.OutputType;

        // Store the shape with every unknown (-1) dimension replaced by BatchSize.
        var graphShape = Graph.GetTensorShape(tfInput);
        var newShape = new long[graphShape.NumDimensions];
        for (int j = 0; j < newShape.Length; j++)
        {
            newShape[j] = graphShape[j] == -1 ? BatchSize : graphShape[j];
        }
        TFInputShapes[i] = new TFShape(newShape);
    }

    Outputs = outputs;
    OutputTypes = new ColumnType[Outputs.Length];
    TFOutputTypes = new TFDataType[Outputs.Length];
    for (int i = 0; i < Outputs.Length; i++)
    {
        var tfOutput = new TFOutput(Graph[Outputs[i]]);
        var shape = Graph.GetTensorShape(tfOutput);

        // A leading -1 marks the batch dimension. Exactly one dimension is dropped for it;
        // the number of dimensions to skip is unrelated to the value of BatchSize.
        bool hasBatchDim = shape.NumDimensions > 0 && shape[0] == -1;
        int[] dims = shape.ToIntArray().Skip(hasBatchDim ? 1 : 0).ToArray();

        var type = TensorFlowUtils.Tf2MlNetType(tfOutput.OutputType);
        OutputTypes[i] = new VectorType(type, dims);
        TFOutputTypes[i] = tfOutput.OutputType;
    }
}