```cuda
#include <cuda_runtime.h>
#include <cfloat>  // FLT_MAX

// Ceiling division; assumed to be defined alongside the kernel in the original source.
#define CEIL(a, b) (((a) + (b) - 1) / (b))

// Row-wise softmax: one warp (32 threads) handles one row of the M x N input.
// Assumes the kernel is launched with blockDim.x == warpSize, so threadIdx.x is the lane id.
__global__ void softmax_row_kernel(float* input, float* output, int M, int N) {
    __shared__ float s_max_val;
    __shared__ float s_sum;

    int laneId = threadIdx.x % warpSize;
    int row = blockIdx.x;
    if (row >= M) return;

    int iteration = CEIL(N, warpSize);

    // Pass 1: each lane scans a strided slice of the row for its local max,
    // then a shuffle reduction leaves the row max in lane 0.
    float max_val = -FLT_MAX;
    #pragma unroll
    for (int i = 0; i < iteration; i++) {
        int col = i * warpSize + laneId;
        max_val = (col < N) ? fmaxf(max_val, input[row * N + col]) : max_val;
    }
    #pragma unroll
    for (int offset = warpSize >> 1; offset > 0; offset >>= 1) {
        max_val = fmaxf(max_val, __shfl_down_sync(0xFFFFFFFF, max_val, offset));
    }
    if (laneId == 0) s_max_val = max_val;
    __syncwarp();  // make s_max_val visible to all lanes before pass 2

    // Pass 2: accumulate exp(x - max) per lane (subtracting the row max keeps
    // expf from overflowing), then reduce to the row sum.
    float sum = 0.0f;
    #pragma unroll
    for (int i = 0; i < iteration; i++) {
        int col = i * warpSize + laneId;
        sum += (col < N) ? expf(input[row * N + col] - s_max_val) : 0.0f;
    }
    #pragma unroll
    for (int offset = warpSize >> 1; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xFFFFFFFF, sum, offset);
    }
    if (laneId == 0) s_sum = sum;
    __syncwarp();  // make s_sum visible to all lanes before the final pass

    // Pass 3: recompute the exponential and normalize each element.
    #pragma unroll
    for (int i = 0; i < iteration; i++) {
        int col = i * warpSize + laneId;
        if (col < N)
            output[row * N + col] = expf(input[row * N + col] - s_max_val) / s_sum;
    }
}
```
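For completeness, a minimal host-side launch sketch under the kernel's stated assumptions: 32 threads per block (one warp) and one block per row. The sizes, buffer names, and the uniform input values are illustrative only; on a softmax over constant inputs every output should equal 1/N, which makes a quick sanity check.

```cuda
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int M = 4, N = 1000;  // illustrative dimensions
    std::vector<float> h_in(M * N, 1.0f), h_out(M * N);

    float *d_in, *d_out;
    cudaMalloc(&d_in, M * N * sizeof(float));
    cudaMalloc(&d_out, M * N * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), M * N * sizeof(float), cudaMemcpyHostToDevice);

    // One warp per row: blockDim.x == warpSize, gridDim.x == number of rows.
    softmax_row_kernel<<<M, 32>>>(d_in, d_out, M, N);
    cudaDeviceSynchronize();

    cudaMemcpy(h_out.data(), d_out, M * N * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %f (expected 1/N = %f)\n", h_out[0], 1.0f / N);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```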