private static T ComputeScan <T, TScanOperation, TScanImplementation>( T value, out ArrayView <T> sharedMemory) where T : unmanaged where TScanOperation : struct, IScanReduceOperation <T> where TScanImplementation : struct, IScanImplementation <T, TScanOperation> { const int SharedMemoryLength = 32; sharedMemory = SharedMemory.Allocate <T>(SharedMemoryLength); int warpIdx = Warp.WarpIdx; TScanOperation scanOperation = default; // Initialize if (Group.DimX / Warp.WarpSize < SharedMemoryLength) { if (warpIdx < 1) { sharedMemory[Group.IdxX] = scanOperation.Identity; } Group.Barrier(); } TScanImplementation scanImplementation = default; var scannedValue = scanImplementation.Scan(value); if (Warp.IsLastLane) { sharedMemory[warpIdx] = scanImplementation.ScanRightBoundary( scannedValue, value); } Group.Barrier(); // Reduce results again in the first warp if (warpIdx < 1) { ref T sharedBoundary = ref sharedMemory[Group.IdxX]; sharedBoundary = PTXWarpExtensions.InclusiveScan <T, TScanOperation>( sharedBoundary); }
/// <summary>
/// Reduces <paramref name="value"/> with <typeparamref name="TReduction"/> in two
/// stages: first via <c>PTXWarpExtensions.Reduce</c>, then across warps through a
/// small shared-memory buffer that every calling thread reads back in full.
/// </summary>
/// <typeparam name="T">The unmanaged element type being reduced.</typeparam>
/// <typeparam name="TReduction">
/// The reduction operation; it supplies <c>Identity</c>, <c>Apply</c> and
/// <c>AtomicApply</c>.
/// </typeparam>
/// <param name="value">The calling thread's contribution.</param>
/// <returns>
/// The combination (via <c>Apply</c>) of all shared-memory banks after the
/// per-warp results have been merged into them.
/// </returns>
/// <remarks>
/// NOTE(review): the method executes three <c>Group.Barrier()</c> calls, so it
/// presumably has to be reached by every thread of the group - confirm against
/// the callers.
/// </remarks>
public static T AllReduce<T, TReduction>(T value)
    where T : unmanaged
    where TReduction : IScanReduceOperation<T>
{
    // Fixed bank count: the atomic updates issued by the individual warps are
    // spread over several shared-memory slots instead of a single one.
    const int NumMemoryBanks = 4;
    var banks = SharedMemory.Allocate<T>(NumMemoryBanks);

    TReduction reduction = default;
    var warp = Warp.ComputeWarpIdx(Group.IdxX);
    var lane = Warp.LaneIdx;

    // Warp 0 seeds every bank with the identity element, striding by the warp
    // size starting from each lane's own index.
    if (warp == 0)
    {
        int initSlot = lane;
        while (initSlot < NumMemoryBanks)
        {
            banks[initSlot] = reduction.Identity;
            initSlot += Warp.WarpSize;
        }
    }
    Group.Barrier();

    // Stage 1: warp-level reduction of the thread-local values.
    var warpResult = PTXWarpExtensions.Reduce<T, TReduction>(value);

    // Stage 2: lane 0 of each warp merges the warp result atomically into the
    // bank selected by its warp index.
    // NOTE(review): assumes Reduce leaves the warp-wide result in lane 0 - confirm.
    if (lane == 0)
    {
        var targetBank = warp % NumMemoryBanks;
        reduction.AtomicApply(ref banks[targetBank], warpResult);
    }
    Group.Barrier();

    // Fold the banks in index order. This stays manually unrolled and therefore
    // has to be kept in sync with NumMemoryBanks.
    var combined = reduction.Apply(banks[0], banks[1]);
    combined = reduction.Apply(combined, banks[2]);
    combined = reduction.Apply(combined, banks[3]);

    // Final barrier after the banks have been read and before returning.
    Group.Barrier();

    return combined;
}
/// <summary>
/// Performs the scan step of this implementation by forwarding
/// <paramref name="value"/> to <c>PTXWarpExtensions.ExclusiveScan</c>,
/// parameterized with <c>T</c> and <c>TScanOperation</c>.
/// </summary>
/// <param name="value">The value passed on to the scan.</param>
/// <returns>The value returned by <c>PTXWarpExtensions.ExclusiveScan</c>.</returns>
public T Scan(T value)
{
    return PTXWarpExtensions.ExclusiveScan<T, TScanOperation>(value);
}