/// <summary>
/// Set dimension of inputs, value infos, outputs and potential Reshape ops.
/// Can be used to make models have dynamic batch size or different static batch sizes.
/// </summary>
/// <param name="graph">Graph whose shape metadata and Reshape targets are updated in place.</param>
/// <param name="dimIndex">Index of the dimension to change in each affected shape.</param>
/// <param name="dimParamOrValue">Symbolic dim param (dynamic) or concrete dim value to set.</param>
public static void SetDim(this GraphProto graph, int dimIndex, DimParamOrValue dimParamOrValue)
{
    // A Reshape op receives its target shape as an input rather than via shape metadata,
    // so that input is patched separately. A symbolic dim maps to the "dynamic" reshape value.
    var reshapeDimValue = dimParamOrValue.IsParam
        ? Ops.Reshape.DynamicReshapeValue
        : dimParamOrValue.Value;
    SetDimInReshapes(graph, dimIndex, reshapeDimValue);

    // Shapes are recorded on graph inputs, value infos and outputs.
    // NOTE: open question whether this should instead be driven by the graph's nodes.
    //
    // Graph inputs that share a name with an initializer are not real inference inputs,
    // so they are skipped.
    var initializerNames = new HashSet<string>(graph.Initializer.Select(initializer => initializer.Name));
    foreach (var graphInput in graph.Input)
    {
        if (initializerNames.Contains(graphInput.Name))
        {
            continue;
        }

        SetDim(graphInput, dimIndex, dimParamOrValue);
    }

    SetDim(graph.ValueInfo, dimIndex, dimParamOrValue);
    SetDim(graph.Output, dimIndex, dimParamOrValue);
}
/// <summary>
/// Applies the per-value-info dimension update to every entry of <paramref name="valueInfos"/>, in order.
/// </summary>
/// <param name="valueInfos">Value infos to update in place.</param>
/// <param name="dimIndex">Index of the dimension to change.</param>
/// <param name="dimParamOrValue">Symbolic dim param or concrete dim value to set.</param>
internal static void SetDim(RepeatedField<ValueInfoProto> valueInfos, int dimIndex, DimParamOrValue dimParamOrValue)
{
    foreach (var info in valueInfos)
    {
        SetDim(info, dimIndex, dimParamOrValue);
    }
}
/// <summary>
/// Overwrites <paramref name="dim"/> with either the symbolic name or the concrete size
/// carried by <paramref name="dimParamOrValue"/>.
/// </summary>
/// <param name="dim">Dimension to overwrite in place.</param>
/// <param name="dimParamOrValue">Selects <c>DimParam</c> (when <c>IsParam</c>) or <c>DimValue</c>.</param>
internal static void SetDim(TensorShapeProto.Types.Dimension dim, DimParamOrValue dimParamOrValue)
{
    // Reset the oneof up front so the dimension is cleared before the new case is assigned.
    dim.ClearValue();

    if (!dimParamOrValue.IsParam)
    {
        dim.DimValue = dimParamOrValue.Value;
        return;
    }

    dim.DimParam = dimParamOrValue.Param;
}
/// <summary>
/// Replaces the dimension at <paramref name="dimIndex"/> of <paramref name="valueInfo"/>'s tensor shape
/// with <paramref name="dimParamOrValue"/>, but only when that dimension is currently the fixed value 1.
/// Value infos that are not tensors, have no recorded shape, or whose rank is too small to contain
/// <paramref name="dimIndex"/> (e.g. scalars) are left untouched instead of throwing.
/// </summary>
/// <param name="valueInfo">Value info to update in place.</param>
/// <param name="dimIndex">Index of the dimension to change; a negative index still throws.</param>
/// <param name="dimParamOrValue">Symbolic dim param or concrete dim value to set.</param>
internal static void SetDim(ValueInfoProto valueInfo, int dimIndex, DimParamOrValue dimParamOrValue)
{
    // Protobuf message/oneof accessors return null when unset: a non-tensor type yields a null
    // TensorType, and a tensor with unknown shape yields a null Shape. Nothing to update then.
    var shape = valueInfo.Type?.TensorType?.Shape;
    if (shape == null)
    {
        return;
    }

    var dims = shape.Dim;
    // Lower-rank tensors (e.g. scalar intermediates) have no dimension at this index.
    if (dimIndex >= dims.Count)
    {
        return;
    }

    var dim = dims[dimIndex];
    // Only a fixed dimension equal to 1 (presumably a placeholder batch size) is replaced;
    // other fixed sizes and existing dim params are kept.
    // TODO: Should perhaps be parameter that says
    // bool shouldSetDimFor(dim)
    if (dim.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue && dim.DimValue == 1)
    {
        SetDim(dim, dimParamOrValue);
    }
}
/// <summary>
/// Set dimension of inputs, value infos, outputs and potential Reshape ops.
/// Shorthand for applying the dim param <c>"N"</c> to dimension index 0,
/// i.e. a dynamic leading (batch) dimension.
/// </summary>
/// <param name="graph">Graph to update in place.</param>
public static void SetDim(this GraphProto graph)
{
    var dynamicBatch = DimParamOrValue.New("N");
    graph.SetDim(0, dynamicBatch);
}