Example #1
0
        /// <summary>
        /// Reduces all elements of <paramref name="src"/> into a single value by launching the
        /// reduce-all kernel <paramref name="kernelName"/> compiled from <paramref name="reduceAllKernels"/>.
        /// When isTwoPassReductionSize(totalElements) is true, a "twoPassA_" kernel first writes
        /// per-block partial results into the device scratch space. A "twoPassB_" kernel then
        /// combines those partials into the output. Otherwise a single "onePass_" kernel writes
        /// the output directly.
        /// </summary>
        /// <param name="reduceAllKernels">Kernel provider; its PTX is obtained via GetPtx(context.Compiler).</param>
        /// <param name="init">Initial accumulator value, converted according to <paramref name="initType"/> and the element type of <paramref name="src"/>.</param>
        /// <param name="initType">Selects how <paramref name="init"/> is turned into the kernel's initial value (see ReduceInitConverter.GetInitValue).</param>
        /// <param name="kernelName">Base kernel name; mangled with the ApplySpecialization and prefixed with the pass name.</param>
        /// <param name="result">Destination tensor, or null — presumably GetWriteTarget then allocates a one-element target (TODO confirm).</param>
        /// <param name="src">Tensor to reduce.</param>
        /// <param name="extraArg">Optional extra kernel argument. It is appended only to the first (or only) pass, never to "twoPassB_".</param>
        /// <returns>The write target obtained from TensorResultBuilder.GetWriteTarget.</returns>
        /// <exception cref="InvalidOperationException">Thrown when <paramref name="src"/> has more than TSCudaContext.MaxDims dimensions.</exception>
        public static Tensor Invoke(CudaReduceAllKernels reduceAllKernels, float init, ReduceInitType initType, string kernelName, Tensor result, Tensor src, object extraArg = null)
        {
            int           deviceId    = CudaHelpers.GetDeviceId(src);
            TSCudaContext context     = CudaHelpers.TSContextForTensor(src);
            CudaContext   cudaContext = context.CudaContextForDevice(deviceId);

            if (src.DimensionCount > TSCudaContext.MaxDims)
            {
                throw new InvalidOperationException("Tensors with dimension count > " + TSCudaContext.MaxDims + " are not supported");
            }

            Tensor writeTarget = TensorResultBuilder.GetWriteTarget(result, src, false, 1);

            if (src.DimensionCount == 0)
            {
                // Return the write target, not `result`. `result` may be null when the caller
                // asked for allocation, and the normal exit path below also returns writeTarget.
                // NOTE(review): no kernel runs here, so the target is not written; confirm
                // whether it should be filled with the init value instead.
                return writeTarget;
            }

            long totalElements         = src.ElementCount();
            ApplySpecialization config = new ApplySpecialization(src);
            // Box each branch separately. `c ? (uint)x : (ulong)x` has static type ulong in C#,
            // so without the casts to object a ulong would always be boxed, even when the
            // kernel is specialized for 32-bit indices.
            object totalElementsTyped  = config.Use32BitIndices ? (object)(uint)totalElements : (object)(ulong)totalElements;
            object initValueTyped      = ReduceInitConverter.GetInitValue(init, initType, src.ElementType);

            dim3 grid;
            dim3 block;

            byte[] ptx            = reduceAllKernels.GetPtx(context.Compiler);
            string fullKernelName = PermutationGenerator.GetMangledName(kernelName, config);

            ManagedCuda.BasicTypes.CUdeviceptr outputDevicePtr = CudaHelpers.GetBufferStart(writeTarget);

            if (isTwoPassReductionSize(totalElements))
            {
                getPass1ReduceBlockGrid(context, deviceId, totalElements, out grid, out block);
                // Dynamic shared memory: one float per thread in the block.
                uint smemSize = block.x * sizeof(float);

                // Pass 1 writes one partial result per block into the device scratch buffer.
                ManagedCuda.BasicTypes.CUdeviceptr scratchSpace = context.ScratchSpaceForDevice(deviceId).buffer;

                // extraArg is passed only when present, so a null is never appended to the kernel arguments.
                if (extraArg == null)
                {
                    InvokeReduceAll(context, cudaContext, ptx, "twoPassA_" + fullKernelName, grid, block, smemSize, config, src, totalElementsTyped, initValueTyped, scratchSpace);
                }
                else
                {
                    InvokeReduceAll(context, cudaContext, ptx, "twoPassA_" + fullKernelName, grid, block, smemSize, config, src, totalElementsTyped, initValueTyped, scratchSpace, extraArg);
                }

                // Pass 2 reduces the numPass1Blocks partials from scratch space into the output.
                uint numPass1Blocks = grid.x;
                getPass2ReduceBlockGrid(context, deviceId, totalElements, out grid, out block);
                smemSize = block.x * sizeof(float);

                InvokeReduceAllPass2(context, cudaContext, ptx, "twoPassB_" + fullKernelName, grid, block, smemSize, config.Use32BitIndices, numPass1Blocks, initValueTyped, scratchSpace, outputDevicePtr);
            }
            else
            {
                getSinglePassReduceBlockGrid(totalElements, out grid, out block);
                // Dynamic shared memory: one float per thread in the block.
                uint smemSize = block.x * sizeof(float);

                if (extraArg == null)
                {
                    InvokeReduceAll(context, cudaContext, ptx, "onePass_" + fullKernelName, grid, block, smemSize, config, src, totalElementsTyped, initValueTyped, outputDevicePtr);
                }
                else
                {
                    InvokeReduceAll(context, cudaContext, ptx, "onePass_" + fullKernelName, grid, block, smemSize, config, src, totalElementsTyped, initValueTyped, outputDevicePtr, extraArg);
                }
            }

            return writeTarget;
        }