示例#1
0
        /// <summary>
        /// Reads a serialized tensor from <paramref name="reader"/> and copies its data into
        /// <paramref name="tensor"/>.
        /// </summary>
        /// <remarks>
        /// The stream is read in this order: the element type (cast to <see cref="ScalarType"/>),
        /// the number of dimensions, each dimension's length, and then the raw element bytes.
        /// Every integer header field is read with <c>reader.Decode()</c>. The destination tensor
        /// must already have the same dtype and shape as the stored data; its data is then set via
        /// <c>SetBytes</c>.
        /// </remarks>
        /// <param name="tensor">Destination tensor; its dtype and shape must match the stored data.</param>
        /// <param name="reader">Reader positioned at the start of a serialized tensor.</param>
        /// <exception cref="ArgumentException">The stored dtype or shape differs from <paramref name="tensor"/>.</exception>
        /// <exception cref="NotImplementedException">The tensor's data is larger than <see cref="int.MaxValue"/> bytes.</exception>
        /// <exception cref="System.IO.EndOfStreamException">The stream ends before all of the tensor's bytes were read.</exception>
        public static void Load(this Tensor tensor, System.IO.BinaryReader reader)
        {
            // First, read the type
            var type = (ScalarType)reader.Decode();

            if (type != tensor.dtype)
            {
                throw new ArgumentException("Mismatched tensor data types while loading.");
            }

            // Then, the shape: a dimension count followed by the length of each dimension.
            var shLen = reader.Decode();

            long[] loadedShape = new long[shLen];

            long totalSize = 1;

            for (int i = 0; i < shLen; ++i)
            {
                loadedShape[i] = reader.Decode();
                totalSize     *= loadedShape[i];
            }

            if (!loadedShape.SequenceEqual(tensor.shape))
            {
                throw new ArgumentException("Mismatched tensor shape while loading.");
            }

            // BinaryReader.ReadBytes takes an int count, so the *byte* length must fit in an int.
            // Checking only the element count let tensors with < 2^31 elements but > 2^31 bytes
            // through, after which the (int) cast wrapped around to a bogus or negative count.
            // The shape now matches the existing tensor, so this product is bounded by its real size.
            //
            // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB
            //
            long byteCount = totalSize * tensor.ElementSize;

            if (byteCount > int.MaxValue)
            {
                throw new NotImplementedException("Loading tensors larger than 2GB");
            }

            var bytes = reader.ReadBytes((int)byteCount);

            // ReadBytes returns a shorter array at end of stream instead of throwing; don't hand
            // truncated data to the tensor.
            if (bytes.Length != byteCount)
            {
                throw new System.IO.EndOfStreamException("Unexpected end of stream while loading tensor data.");
            }

            tensor.SetBytes(bytes);
        }