1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
// Tiled out-of-place matrix transpose.
//   input  : M rows x N cols, row-major (element (r, c) at input[r * N + c])
//   output : N rows x M cols, row-major (output[c * M + r] = input[r * N + c])
//
// Launch layout (required; not checked at runtime):
//   blockDim = dim3(TILE_SIZE, BLOCK_ROWS)
//   gridDim  = dim3(ceil(N / TILE_SIZE), ceil(M / TILE_SIZE))
// Each block moves one TILE_SIZE x TILE_SIZE tile.
// Each thread copies TILE_SIZE / BLOCK_ROWS elements, stepping BLOCK_ROWS rows per iteration.
// Partial tiles at the right/bottom edges are handled by the bounds checks.
//
// Shared memory: TILE_SIZE * (TILE_SIZE + 1) floats, statically allocated.
// The extra padding column makes the column-wise tile reads in the store
// phase fall in different shared-memory banks.
//
// Preconditions:
//   - output and input must not overlap. An in-place transpose is not supported:
//     different blocks read and write overlapping tiles concurrently.
//   - Indices are computed in 64-bit, so M * N may exceed INT_MAX.
template <int TILE_SIZE = 32, int BLOCK_ROWS = 8>
__global__ void transpose_v2(float* __restrict__ output,
                             const float* __restrict__ input,
                             int M, int N) {
    static_assert(TILE_SIZE > 0 && BLOCK_ROWS > 0,
                  "TILE_SIZE and BLOCK_ROWS must be positive");
    // Otherwise threadIdx.y + i can exceed TILE_SIZE - 1 and index past `tile`.
    static_assert(TILE_SIZE % BLOCK_ROWS == 0,
                  "TILE_SIZE must be a multiple of BLOCK_ROWS");

    __shared__ float tile[TILE_SIZE][TILE_SIZE + 1];

    const long long rows = M;  // input rows == output columns
    const long long cols = N;  // input columns == output rows

    // Load phase: consecutive threadIdx.x read consecutive input columns.
    const long long in_col  = static_cast<long long>(blockIdx.x) * TILE_SIZE + threadIdx.x;
    const long long in_row0 = static_cast<long long>(blockIdx.y) * TILE_SIZE + threadIdx.y;
#pragma unroll
    for (int i = 0; i < TILE_SIZE; i += BLOCK_ROWS) {
        const long long r = in_row0 + i;
        if (in_col < cols && r < rows) {
            tile[threadIdx.y + i][threadIdx.x] = input[r * cols + in_col];
        }
    }

    // Every thread reaches this barrier. The store phase reads tile entries
    // that were written by other threads of the block.
    __syncthreads();

    // Store phase: block (bx, by) writes output tile (bx, by), which is the
    // transpose of the input tile (by, bx) loaded above. Consecutive
    // threadIdx.x write consecutive output columns.
    const long long out_col  = static_cast<long long>(blockIdx.y) * TILE_SIZE + threadIdx.x;
    const long long out_row0 = static_cast<long long>(blockIdx.x) * TILE_SIZE + threadIdx.y;
#pragma unroll
    for (int i = 0; i < TILE_SIZE; i += BLOCK_ROWS) {
        const long long r = out_row0 + i;
        if (out_col < rows && r < cols) {
            output[r * rows + out_col] = tile[threadIdx.x][threadIdx.y + i];
        }
    }
}
|