Exemplo n.º 1
0
        /// <summary>
        /// Sets the dimension at <paramref name="dimIndex"/> on the graph's inference inputs,
        /// value infos and outputs, and adjusts Reshape ops to match.
        /// Useful for giving a model a dynamic batch size or a different static batch size.
        /// </summary>
        /// <param name="graph">Graph that is modified in place.</param>
        /// <param name="dimIndex">Index of the dimension to set.</param>
        /// <param name="dimParamOrValue">Named parameter or fixed value to assign to the dimension.</param>
        public static void SetDim(this GraphProto graph, int dimIndex, DimParamOrValue dimParamOrValue)
        {
            // A Reshape op receives its "new shape" as an input, so that input has to
            // reflect the new dim as well, e.g. be set to -1 when the dim is dynamic.
            SetDimInReshapes(
                graph,
                dimIndex,
                dimParamOrValue.IsParam
                    ? Ops.Reshape.DynamicReshapeValue
                    : dimParamOrValue.Value);

            // Open question: should this be driven by nodes instead, with inputs and
            // outputs handled based on that?

            // Shapes live in inputs, value infos and outputs.
            // Initializers are also listed as inputs, but only real inference inputs
            // should be changed, so collect the initializer names to skip them.
            var initializerNames = new HashSet<string>();
            foreach (var initializer in graph.Initializer)
            {
                initializerNames.Add(initializer.Name);
            }

            foreach (var input in graph.Input)
            {
                if (initializerNames.Contains(input.Name))
                {
                    continue;
                }

                SetDim(input, dimIndex, dimParamOrValue);
            }

            SetDim(graph.ValueInfo, dimIndex, dimParamOrValue);
            SetDim(graph.Output, dimIndex, dimParamOrValue);
        }
Exemplo n.º 2
0
 /// <summary>
 /// Applies the dimension change to every value info in the given collection.
 /// </summary>
 /// <param name="valueInfos">Value infos to update in place.</param>
 /// <param name="dimIndex">Index of the dimension to set.</param>
 /// <param name="dimParamOrValue">Named parameter or fixed value to assign.</param>
 internal static void SetDim(RepeatedField<ValueInfoProto> valueInfos,
                             int dimIndex, DimParamOrValue dimParamOrValue)
 {
     foreach (var info in valueInfos)
     {
         SetDim(info, dimIndex, dimParamOrValue);
     }
 }
Exemplo n.º 3
0
 /// <summary>
 /// Overwrites a single dimension with either a named parameter or a fixed value.
 /// </summary>
 /// <param name="dim">Dimension to modify in place.</param>
 /// <param name="dimParamOrValue">Named parameter or fixed value to assign.</param>
 internal static void SetDim(TensorShapeProto.Types.Dimension dim,
                             DimParamOrValue dimParamOrValue)
 {
     // Clear whatever the dimension held before assigning the replacement.
     dim.ClearValue();

     if (!dimParamOrValue.IsParam)
     {
         dim.DimValue = dimParamOrValue.Value;
         return;
     }

     dim.DimParam = dimParamOrValue.Param;
 }
Exemplo n.º 4
0
        /// <summary>
        /// Sets the dimension at <paramref name="dimIndex"/> of a single value info, but only
        /// when that dimension currently holds the fixed value 1.
        /// </summary>
        /// <remarks>
        /// Value infos that have no tensor shape, or whose shape has too few dimensions to
        /// contain <paramref name="dimIndex"/>, are left untouched instead of throwing.
        /// A negative <paramref name="dimIndex"/> still throws from the indexer.
        /// </remarks>
        /// <param name="valueInfo">Value info to modify in place.</param>
        /// <param name="dimIndex">Index of the dimension to set.</param>
        /// <param name="dimParamOrValue">Named parameter or fixed value to assign.</param>
        internal static void SetDim(ValueInfoProto valueInfo,
                                    int dimIndex, DimParamOrValue dimParamOrValue)
        {
            // Type, tensor type and shape are message fields that may be null when unset
            // (e.g. non-tensor values or tensors without shape information).
            var dims = valueInfo.Type?.TensorType?.Shape?.Dim;
            if (dims == null || dims.Count <= dimIndex)
            {
                // Nothing to set: no shape, or rank too low (e.g. a scalar).
                return;
            }

            var dim = dims[dimIndex];

            // Dimensions that are already parameters or unset are not touched.
            if (dim.ValueCase != TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue)
            {
                return;
            }

            // TODO: Should perhaps be parameter that says
            //       bool shouldSetDimFor(dim)
            if (dim.DimValue == 1)
            {
                SetDim(dim, dimParamOrValue);
            }
        }
Exemplo n.º 5
0
 /// <summary>
 /// Sets the leading dimension (index 0) of inputs, value infos, outputs and potential
 /// Reshape ops to the dynamic batch size parameter 'N'.
 /// </summary>
 /// <param name="graph">Graph that is modified in place.</param>
 public static void SetDim(this GraphProto graph)
 {
     graph.SetDim(dimIndex: 0, dimParamOrValue: DimParamOrValue.New("N"));
 }