/// <summary>
/// Reads serialized tensor contents from <paramref name="reader"/> into <paramref name="tensor"/>.
/// The stream is read in this order: the element type, the number of dimensions, each dimension
/// size (all via <c>reader.Decode()</c>), and finally the raw element bytes.
/// NOTE(review): presumably mirrors the layout written by the matching save routine — confirm.
/// </summary>
/// <param name="tensor">Destination tensor; its dtype and shape must match the serialized data.</param>
/// <param name="reader">Reader positioned at the start of the serialized tensor.</param>
/// <exception cref="ArgumentException">The stored element type or shape differs from <paramref name="tensor"/>.</exception>
/// <exception cref="NotImplementedException">The payload is larger than <see cref="int.MaxValue"/> bytes.</exception>
/// <exception cref="System.IO.EndOfStreamException">The stream ends before all element bytes are read.</exception>
public static void Load(this Tensor tensor, System.IO.BinaryReader reader) {
    // First, read the type
    var type = (ScalarType)reader.Decode();

    if (type != tensor.dtype) {
        throw new ArgumentException("Mismatched tensor data types while loading.");
    }

    // Then, the shape
    var shLen = reader.Decode();
    long[] loadedShape = new long[shLen];

    long totalSize = 1;
    for (int i = 0; i < shLen; ++i) {
        loadedShape[i] = reader.Decode();
        totalSize *= loadedShape[i];
    }

    if (!loadedShape.SequenceEqual(tensor.shape)) {
        throw new ArgumentException("Mismatched tensor shape while loading.");
    }

    // BinaryReader.ReadBytes takes an int count, so the limit applies to the number of BYTES,
    // not elements. Checking only the element count would let multi-byte dtypes overflow the
    // (int) cast below and request a negative or wrong number of bytes.
    //
    // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB
    //
    long byteCount = totalSize * tensor.ElementSize;
    if (byteCount > int.MaxValue) {
        throw new NotImplementedException("Loading tensors larger than 2GB");
    }

    var bytes = reader.ReadBytes((int)byteCount);

    // ReadBytes returns a shorter array (rather than throwing) when the stream ends early;
    // fail loudly instead of passing a truncated buffer to the tensor.
    if (bytes.Length != byteCount) {
        throw new System.IO.EndOfStreamException(
            $"Expected {byteCount} bytes of tensor data but the stream ended after {bytes.Length}.");
    }

    tensor.SetBytes(bytes);
}