1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
// ---------------------------------------------------------------------------
// Numerically stable softmax over a flat array of N floats, in three passes:
//   1. max_kernel     : max_val[0] = max_i input[i]
//   2. reduce_kernel  : sum[0]     = sum_i expf(input[i] - max_val[0])
//   3. softmax_kernel : output[i]  = expf(input[i] - max_val[0]) / sum[0]
// Passes 1 and 2 merge per-block partials with device-wide atomics, so their
// single-float accumulators must hold the reduction identity (-FLT_MAX and
// 0.0f respectively) before launch; softmax_init_kernel / launch_softmax do
// this in stream order.
//
// Launch preconditions for the reduction kernels:
//   * 1-D grid, 1-D block; any grid size works (grid-stride loops);
//   * blockDim.x is a multiple of 32 and <= 1024 (full-mask warp shuffles,
//     one shared-memory slot per warp).
// Note: the float atomicAdd in pass 2 makes the summation order, and hence the
// last bits of sum[0], vary from run to run.
// ---------------------------------------------------------------------------

namespace softmax_detail {

constexpr int kWarpSize = 32;          // lanes per warp on all current NVIDIA GPUs
constexpr int kMaxWarpsPerBlock = 32;  // 1024 threads per block / 32 lanes

struct MaxOp {
    __device__ __forceinline__ float operator()(float a, float b) const { return fmaxf(a, b); }
};

struct SumOp {
    __device__ __forceinline__ float operator()(float a, float b) const { return a + b; }
};

// Tree-reduces `v` over the 32 lanes of the calling warp; lane 0 receives the
// warp-wide result. Every lane of the warp must call this (full shuffle mask).
template <typename Op>
__device__ __forceinline__ float warp_reduce(float v, Op op) {
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        v = op(v, __shfl_down_sync(0xffffffffu, v, offset));
    }
    return v;
}

// Reduces `v` over the whole thread block; the result is valid only in
// threadIdx.x == 0. `identity` pads the per-warp slots that no warp filled.
// Contains __syncthreads(): every thread of the block must reach it.
template <typename Op>
__device__ float block_reduce(float v, float identity, Op op) {
    __shared__ float warp_partials[kMaxWarpsPerBlock];
    const int lane = threadIdx.x % kWarpSize;
    const int warp = threadIdx.x / kWarpSize;
    const int num_warps = (blockDim.x + kWarpSize - 1) / kWarpSize;

    v = warp_reduce(v, op);
    if (lane == 0) warp_partials[warp] = v;
    __syncthreads();

    // Warp 0 folds the per-warp partials. All 32 of its lanes participate so
    // the full-mask shuffle is valid; lanes past num_warps contribute identity.
    if (warp == 0) {
        v = (lane < num_warps) ? warp_partials[lane] : identity;
        v = warp_reduce(v, op);
    }
    return v;
}

// *address = fmaxf(*address, value), atomically, for a float in global memory
// (there is no native float atomicMax). A NaN `value` is never stored.
// The initial plain read is only a hint: the stored value only ever grows, so
// a stale read can at worst cost one extra CAS round, never a lost update.
__device__ __forceinline__ void atomic_max(float* address, float value) {
    int* address_as_int = reinterpret_cast<int*>(address);
    int observed = *address_as_int;
    while (value > __int_as_float(observed)) {
        const int expected = observed;
        observed = atomicCAS(address_as_int, expected, __float_as_int(value));
        if (observed == expected) break;  // our value was stored
    }
}

}  // namespace softmax_detail

// Seeds the single-float accumulators used by max_kernel and reduce_kernel.
// Launch <<<1, 1>>> on the same stream, before those kernels.
__global__ void softmax_init_kernel(float* max_val, float* sum) {
    *max_val = -FLT_MAX;
    *sum = 0.0f;
}

// Pass 1: folds max(d_in[0..N)) into d_out[0].
// Precondition: d_out[0] holds -FLT_MAX (or any lower bound) before launch.
__global__ void max_kernel(const float* d_in, float* d_out, int N) {
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    float thread_max = -FLT_MAX;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < N;
         i += stride) {
        thread_max = fmaxf(thread_max, d_in[i]);
    }
    const float block_max =
        softmax_detail::block_reduce(thread_max, -FLT_MAX, softmax_detail::MaxOp());
    if (threadIdx.x == 0) softmax_detail::atomic_max(d_out, block_max);
}

// Pass 2: adds sum_i expf(d_in[i] - *max_val) into d_out[0].
// Precondition: d_out[0] holds 0.0f before launch, and *max_val is final
// (max_kernel completed earlier on the same stream). Subtracting the maximum
// keeps every exponent <= 0, so no term overflows.
__global__ void reduce_kernel(const float* d_in, float* d_out, const float* max_val, int N) {
    const float row_max = *max_val;
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    float thread_sum = 0.0f;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < N;
         i += stride) {
        thread_sum += expf(d_in[i] - row_max);
    }
    const float block_sum =
        softmax_detail::block_reduce(thread_sum, 0.0f, softmax_detail::SumOp());
    if (threadIdx.x == 0) atomicAdd(d_out, block_sum);
}

// Pass 3: output[i] = expf(input[i] - *max_val) / *sum for i in [0, N).
// Elementwise on the same index, so output may alias input. Any 1-D launch shape.
__global__ void softmax_kernel(const float* input, float* output, const float* sum,
                               const float* max_val, int N) {
    const float row_max = *max_val;
    const float denom = *sum;
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < N;
         i += stride) {
        output[i] = expf(input[i] - row_max) / denom;
    }
}

// Host driver: enqueues output = softmax(input) over N floats on `stream`.
// input/output are device arrays of N floats (output may alias input). sum and
// max_val each point to one device float used as scratch; once the stream
// completes they hold the normalizer and the maximum.
// Returns the first error cudaGetLastError() reports after a launch. Errors
// during execution surface at the caller's next synchronizing call.
cudaError_t launch_softmax(const float* input, float* output, float* sum, float* max_val,
                           int N, cudaStream_t stream = 0) {
    if (N <= 0) return cudaSuccess;  // a zero-block grid is an invalid launch configuration

    const int block_size = 256;                      // multiple of 32, <= 1024
    const int grid_size = (N - 1) / block_size + 1;  // ceil(N / block_size) without overflow

    softmax_init_kernel<<<1, 1, 0, stream>>>(max_val, sum);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) return err;

    max_kernel<<<grid_size, block_size, 0, stream>>>(input, max_val, N);
    if ((err = cudaGetLastError()) != cudaSuccess) return err;

    reduce_kernel<<<grid_size, block_size, 0, stream>>>(input, sum, max_val, N);
    if ((err = cudaGetLastError()) != cudaSuccess) return err;

    softmax_kernel<<<grid_size, block_size, 0, stream>>>(input, output, sum, max_val, N);
    return cudaGetLastError();
}
|