Example #1
0
        /// <summary>
        /// Reads the dimension sizes out of an ONNX tensor shape, normalizing the leading dimension.
        /// </summary>
        /// <param name="tensorShapeProto">ONNX's tensor shape; <c>null</c> denotes a scalar.</param>
        /// <returns>
        /// <c>null</c> when <paramref name="tensorShapeProto"/> is <c>null</c>. Otherwise, a list with one entry
        /// per dimension, each produced by <c>GetDimValue</c>. If the first entry is zero or negative, it is
        /// replaced by 1.
        /// </returns>
        public static IEnumerable <int> GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto)
        {
            // Scalars carry no dimensionality; that is signalled to callers with null.
            if (tensorShapeProto == null)
            {
                return null;
            }

            var shape = new List<int>();
            foreach (var dimension in tensorShapeProto.Dim)
            {
                shape.Add(GetDimValue(dimension));
            }

            // ONNX uses the leading dimension for the batch size. A value of -1 there lets OnnxRuntime
            // run inference on several rows in one call. ML.NET only treats a vector as having a known
            // size when every dimension is greater than zero, so an unset leading dimension would make
            // every ONNX vector look variable-sized. GetDimValue maps negative sizes to 0, which is still
            // wanted for the trailing dimensions, so only the leading entry is patched here. Testing
            // "<= 0" (rather than "== 0") keeps this correct even if GetDimValue's clamping changes.
            bool leadingDimUnset = shape.Count > 0 && shape[0] <= 0;
            if (leadingDimUnset)
            {
                shape[0] = 1;
            }

            return shape;
        }
        /// <summary>
        /// Reads the dimension sizes out of an ONNX tensor shape.
        /// </summary>
        /// <param name="tensorShapeProto">ONNX's tensor shape; <c>null</c> denotes a scalar.</param>
        /// <returns>
        /// <c>null</c> when <paramref name="tensorShapeProto"/> is <c>null</c>. Otherwise, a list with one entry
        /// per dimension, each taken verbatim from <c>GetDimValue</c>. No dimension, including the first,
        /// is adjusted afterwards.
        /// </returns>
        public static IEnumerable <int> GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto)
        {
            if (tensorShapeProto == null)
            {
                // A scalar has no dimensions, which is reported as null rather than an empty list.
                return null;
            }

            var shape = new List<int>();
            foreach (var dimension in tensorShapeProto.Dim)
            {
                shape.Add(GetDimValue(dimension));
            }

            return shape;
        }