예제 #1
0
        /// <summary>
        /// Builds a dense tensor from the model buffer that backs <paramref name="tensor"/>.
        /// </summary>
        /// <typeparam name="T">Unmanaged element type the buffer bytes are reinterpreted as.</typeparam>
        /// <param name="model">Model whose buffer table is indexed by the tensor's buffer id.</param>
        /// <param name="tensor">Tensor description supplying the buffer index and the shape.</param>
        /// <returns>
        /// A <see cref="DenseTensor{T}"/> holding a copy of the buffer contents, with the
        /// dimensions returned by the tensor's shape array.
        /// </returns>
        /// <remarks>
        /// No check is made here that <typeparamref name="T"/> matches the tensor's declared
        /// element type; the bytes are reinterpreted as-is.
        /// </remarks>
        public static Tensor <T> GetTensor <T>(this tflite.Model model, tflite.Tensor tensor)
            where T : unmanaged
        {
            var rawBytes = model.Buffers((int)tensor.Buffer).Value.GetDataBytes();

            // ToArray copies the reinterpreted data, so the result does not alias the model buffer.
            var elements = MemoryMarshal.Cast<byte, T>(rawBytes).ToArray();
            var dimensions = tensor.GetShapeArray();

            return new DenseTensor<T>(elements, dimensions);
        }
예제 #2
0
        /// <summary>
        /// Reads the single value of a scalar tensor from its backing model buffer.
        /// </summary>
        /// <typeparam name="T">Unmanaged type the buffer bytes are reinterpreted as.</typeparam>
        /// <param name="model">Model whose buffer table is indexed by the tensor's buffer id.</param>
        /// <param name="tensor">Tensor description; it must report a shape length of zero.</param>
        /// <returns>The first <typeparamref name="T"/> element found in the tensor's buffer.</returns>
        /// <exception cref="InvalidOperationException">
        /// The tensor's shape length is not zero.
        /// </exception>
        public static T GetScalar <T>(this tflite.Model model, tflite.Tensor tensor)
            where T : unmanaged
        {
            var isScalar = tensor.ShapeLength == 0;
            if (!isScalar)
            {
                throw new InvalidOperationException("Tensor is not a scalar");
            }

            var rawBytes = model.Buffers((int)tensor.Buffer).Value.GetDataBytes();
            var values = MemoryMarshal.Cast<byte, T>(rawBytes);

            // Only the leading element is read; any bytes after it are ignored.
            return values[0];
        }
예제 #3
0
        /// <summary>
        /// Builds a dense tensor from the model buffer that backs <paramref name="tensor"/>.
        /// </summary>
        /// <typeparam name="T">Unmanaged element type the buffer bytes are reinterpreted as.</typeparam>
        /// <param name="model">Model whose buffer table is indexed by the tensor's buffer id.</param>
        /// <param name="tensor">Tensor description supplying the type, buffer index and shape.</param>
        /// <returns>
        /// A <see cref="DenseTensor{T}"/> holding a copy of the buffer contents, with the
        /// dimensions returned by the tensor's shape array.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// <typeparamref name="T"/> is <see cref="float"/> but the tensor is not FLOAT32.
        /// </exception>
        /// <remarks>
        /// The element type is validated only for <see cref="float"/>; other choices of
        /// <typeparamref name="T"/> are reinterpreted without a type check.
        /// </remarks>
        public static Tensor <T> GetTensor <T>(this tflite.Model model, tflite.Tensor tensor)
            where T : unmanaged
        {
            if (typeof(T) == typeof(float))
            {
                if (tensor.Type != tflite.TensorType.FLOAT32)
                {
                    throw new InvalidOperationException($"expect FLOAT32 tensor but got {tensor.Type}, use '--inference_type=FLOAT' when converting via toco.");
                }
            }

            var rawBytes = model.Buffers((int)tensor.Buffer).Value.GetDataBytes();

            // ToArray copies the reinterpreted data, so the result does not alias the model buffer.
            var elements = MemoryMarshal.Cast<byte, T>(rawBytes).ToArray();
            var dimensions = tensor.GetShapeArray();

            return new DenseTensor<T>(elements, dimensions);
        }