/// <summary>
/// Produces a new tensor that holds a copy of <paramref name="X"/>'s data but is described by
/// <paramref name="newShape"/>.
/// </summary>
/// <param name="X">Source tensor whose data is copied.</param>
/// <param name="newShape">Shape of the result; asserted to have the same <c>length</c> as <c>X.shape</c>.</param>
/// <returns>A newly allocated tensor of shape <paramref name="newShape"/>.</returns>
protected override Tensor CopyAndReshape(Tensor X, TensorShape newShape)
{
    var srcShape = X.shape;
    Assert.AreEqual(srcShape.length, newShape.length);

    // The "Copy" kernel copies data while keeping the shape unchanged. To reuse it for a reshape,
    // the destination is allocated with the new shape, but both kernel bindings are given the
    // source shape so the kernel performs a plain same-shape copy.
    var result = NewTensor(newShape);
    var kernel = BestKernel(ComputeKernelLibrary.Copy(srcShape, srcShape));
    kernel.SetTensor("X", srcShape, Pin(X).buffer);
    kernel.SetTensor("O", srcShape, Pin(result).buffer);

    // All-zero _Pad for this copy.
    int[] zeroPad = new int[4];
    kernel.shader.SetInts("_Pad", zeroPad);
    kernel.Dispatch();

    return result;
}
/// <summary>
/// Concatenates <paramref name="tensors"/> along <paramref name="axis"/> into a newly allocated tensor.
/// Each input is written into the output by one dispatch of the "Copy" kernel.
/// </summary>
/// <param name="tensors">Inputs to concatenate, in order.</param>
/// <param name="axis">Concatenation axis; resolved through <c>shape.Axis</c> of the output before use.</param>
/// <returns>A new tensor whose shape is computed by <c>TensorExtensions.Concat</c> from the input shapes.</returns>
public override Tensor Concat(Tensor[] tensors, int axis)
{
    var inputShapes = tensors.Select(t => t.shape).ToArray();
    var output = NewTensor(TensorExtensions.Concat(inputShapes, axis));

    var concatAxis = output.shape.Axis(axis);

    // Running write offset into the output, handed to the Copy kernel as _Pad.
    // NOTE(review): assumes the kernel treats _Pad as a destination offset — confirm against the kernel source.
    int[] offsets = new int[4];
    for (int i = 0; i < tensors.Length; ++i)
    {
        var input = tensors[i];
        var kernel = BestKernel(ComputeKernelLibrary.Copy(input.shape, output.shape));
        kernel.SetTensor("X", input.shape, Pin(input).buffer);
        kernel.SetTensor("O", output.shape, Pin(output).buffer);
        kernel.shader.SetInts("_Pad", offsets);
        kernel.Dispatch();

        // The next input begins right after this one along the concatenation axis.
        offsets[concatAxis] += input.shape[concatAxis];
    }

    return output;
}