General Matrix Multiplication
GEMM 优化本身是一个非常值得讨论的课题,其优化也涉及 GPU 中优化的大多数常用的技巧。这部分以解析知乎大佬有了琦琦的棍子文章中的代码进行解读,也作为代码阅读笔记梳理整个思路。
首先,其优化技巧分块计算、shared memory 的多次利用、register 的多次利用以及各种 bank 的 conflict 解决,有的甚至会涉及到汇编层面的优化,这里有些技巧在基础篇已经讲过,就不再赘述了。
其次,简单叙述一下优化的思路,主要的思路就是对矩阵进行分块计算,不同 block 负责计算出 C 中的不同部分,同时在 block 内又让不同线程负责不同部分,这里面为了能多次利用 shared memory,需要进行多次循环,因此在 block 内有多次大循环,在大循环内又有每个线程中的多次小循环。因为涉及到把数据不断搬到 shared memory,所以作者设计了预取 prefetch 的做法,这样做可以掩盖 io 的 latency,因此也要设计哪些线程搬运哪些数据。由于可能在访问 shared memory 的时候有 bank conflict,所以也要设计哪些线程访问哪些内存。
分块计算的思路
首先如下图,对 C 进行分块:

由图可知,C 被分为 MxN/BlOCK_SIZE_M/BlOCK_SIZE_N 块,每块的大小为高 BLOCK_SIZE_M,宽 BLOCK_SIZE_N,每一块对应 A 中的相应行,对应 B 中相应列。其中一个块交给一个 block 来计算,在本次示例中,A、B、C 都是 2048 的方阵,每块的大小为 128 的方阵,因此需要 256 个 block。

对于每个 block 而言,实际上是 A 中的某些行形成的矩阵(Ai)和 B 中的某些列形成的矩阵(Bi)进行矩阵乘法得到的。如上图所示,我们可以对 Ai 中的按列进行再一次划分(Aij)和 Bi 中按照相同的规则按行再次进行划分(Bij),其实可以发现 Ai 中的第 i 列只会和 Bi 中的第 i 行进行相应的运算,其实运算结果的矩阵就是 128 的方阵,所以其实 Ai 和 Bi 的矩阵乘法结果也可以看成多个 128 方阵的累加结果。对于一个 block 来说要进行 K/BLOCK_SIZE_K=8 次循环,然后将结果进行点对点叠加得到。

对于每个 Aij 和 Bij 进行矩阵乘法的时候也不是直接按传统的 Aij 的某一行与 Bij 的某一列进行向量相乘得到最终结果的一个元素来算的,我们每次只计算 Aij 中的第 m 列和 Bi 中的第 m 行的运算结果。这实际上是有考虑 shared memory 的 bank conflict,并且虽然我们设定每个线程在每次小迭代中只计算 Aij 中的 8 个元素和 Bij 中的 8 个元素运算的结果,然而这 8 个元素并不是紧挨着的,如上如可知,我们开了 256 个线程来计算这 128x128 的矩阵,每个线程负责 64 个元素。256 个线程一共 8 个 wrap,我们需要使得每个 wrap 内不发生 bank conflict。上图展示了这个 8 个 wrap 负责的位置,以及 wrap 内线程的负责情况。下面先详细介绍一下 bank conflict:
Bank Conflict

在 shared memory 中的连续的内存被放入 32 个 bank 中,通常每 4 Bytes 一个 bank。如上图所示,一个 float 或者一个 int 占某一个 bank 中的一格,因此可以将一个 bank 理解为一个内存块,一个 bank 是一列,32 个 bank 组成 shared memory。一般 shared memory 的大小是 64 KB,也就是说一个 bank 最多有 2000 Bytes(500 行)。当我们给一个 wrap 中线程分配访问 shared memory 的地址时,也就是分配哪一个线程访问哪一个 bank 中的哪一格。什么是 bank conflict?当 同一 wrap 中 的 不同线程,访问 同一 bank 中的 不同地址 ,此时不能并行访问,只能串行,所以会导致速度下降。shared memory 中的访问可能性:
- 不同线程,访问不同 bank;这种情况是最希望的,因为可以并行访问
- 不同线程,访问同一 bank 中的同一地址(同一格),这种情况 gpu 会自动广播给所有线程,因此不影响速度
- 不同线程,访问同一 bank 中的不同地址(比如有 64 个线程,一个线程访问 4 Bytes,那么第 0 线程访问的 0,第 32 线程访问的 32,他们都属于 bank 0)这时候就会 conflict
为何不讨论不同 wrap 之间的 conflict 呢?因为不同 wrap 不能同时访问 SM 中的 shared memory。关于如何解决呢? 目前一般有两个办法:
- 对 shared memory 开辟的长度进行改变,比如 32x6,对 0 线程而言 ,tid x 6 / 32,那么 16 线程就会发现冲突。如果将其改为 32 x 7 就不会了
- 对线程访问的地址进行设计,使其不会冲突
最后我们基于他的代码进行分析和学习(源代码注释也挺清楚的了):
template <
const int BLOCK_SIZE_M, // C 中每个分块的高
const int BLOCK_SIZE_K, // A 中每次需要搬到共享内存块的宽
const int BLOCK_SIZE_N, // C 中每个分块的宽
const int THREAD_SIZE_Y, // C 中每个分块中一个线程需要计算的块的宽
const int THREAD_SIZE_X, // C 中每个分块中一个线程需要计算的块的高
const bool ENABLE_DOUBLE_BUFFER // 是否开启预取
>
__global__ void Sgemm(
float * __restrict__ A,
float * __restrict__ B,
float * __restrict__ C,
const int M,
const int N,
const int K) {
// Block 索引
int bx = blockIdx.x;
int by = blockIdx.y;
// Thread 索引
int tx = threadIdx.x;
int ty = threadIdx.y;
// 一个 block 线程的总数
const int THREAD_X_PER_BLOCK = BLOCK_SIZE_N / THREAD_SIZE_X;
const int THREAD_Y_PER_BLOCK = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int THREAD_NUM_PER_BLOCK = THREAD_X_PER_BLOCK * THREAD_Y_PER_BLOCK;
// 线程在当前 block 中的索引
const int tid = ty * THREAD_X_PER_BLOCK + tx;
// 为搬运 A、B 矩阵,开了两份共享内存
__shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
// 在寄存器上开辟了记录累积结果以便把最终结果搬到 C 中
float accum[THREAD_SIZE_Y][THREAD_SIZE_X];
#pragma unroll
for(int i=0; i<THREAD_SIZE_Y; i++){
#pragma unroll
for(int j=0; j<THREAD_SIZE_X; j++){
accum[i][j]=0.0;
}
}
// 当前线程为搬运共享内存到寄存器上所开辟的内存
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];
// 这里计算每个线程每次搬运 4 个 float 需要搬运的次数(从 global 搬到 shared 上)
const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (THREAD_NUM_PER_BLOCK * 4);
const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / (THREAD_NUM_PER_BLOCK * 4);
float ldg_a_reg[4*ldg_num_a];
float ldg_b_reg[4*ldg_num_b];
// 搬运一行需要的线程数
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;
// 计算当前线程应该从哪一行哪一列开始搬运
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;
// 总体搬运一次需要跨过的行数 -- stride
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;
// 指向属于该 block 那个块的首地址
A = &A[(BLOCK_SIZE_M * by)* K];
B = &B[BLOCK_SIZE_N * bx];
// 需要该线程搬运共享内存的索引(从 shared 搬到 register 上)
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int a_tile_index = warp_id/2*16 + lane_id/8*4;
const int b_tile_index = warp_id%2*32 + lane_id%8*4;
上面这份代码主要对需要的 index 进行了计算,其中
// 需要该线程搬运共享内存的索引(从 shared 搬到 register 上)
const int warp_id = tid / 32;
const int lane_id = tid % 32;
const int a_tile_index = warp_id/2*16 + lane_id/8*4;
const int b_tile_index = warp_id%2*32 + lane_id%8*4;
这段代码的作用就是为了避免 bank conflict ,其效果如讲解分块时的图是一致的。首先,进行了计算该线程属于哪一个 wrap,在 wrap 中的局部 id 是多少,然后计算去共享内存搬运数据时从哪一个地址开始,举例说明:对于 A 矩阵需要搬运的共享内存而言,根据 wrap id 来确定大的位置(即这一片都是这个 wrap 需要搬运的数据,可参考前面分块画那个示意图)warp_id/2*16
得知每行两个 wrap,排完两个 wrap,跳(stride)16 个 foat。从 lane_id/8*4
可以看出 在 wrap 内 每行 8 个线程,也就是一共 4 行,4 行跳了 16 float,所以一个线程搬运 A 中的 4 个 float 刚好可以用 float4 这样可以访存对齐。由于开 256 个线程,一共 8 个 wrap 每行两个,一共 4 行,那才 64 float。但 A (转置后)一行有 128 个 float,所以需要搬运两次,因此一个线程搬运 8 个 float。
需要的 index 计算完成后,接下来就是数据的搬运了
// 首先将第一次大迭代相应的数据 global memory 搬运至 shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL, // col
K)]);
As[0][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[0][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[0][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[0][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
FETCH_FLOAT4(Bs[0][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(B[OFFSET(
B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N )]);
}
__syncthreads();
// 再把 shared memory 中的第一次小迭代数据搬到 register 中
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[0][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[0][0][a_tile_index + 64]);
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[0][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[0][0][b_tile_index + 64]);
可以注意到上一段代码,最后搬到 register 确实搬了两次。首先,这段代码将第一次大迭代需要的数据存入了 shared memory 里面,然后把第一次小迭代需要的数据存入了 register,其中:
FETCH_FLOAT4
:是利用 cuda 中的类型 float4 而设置的宏,很明显一次可以取 4 个 float,可以避免访存不对齐的问题;ldg_a_reg
:显然要想从 A 中搬运数据到存 A 数据的 shared memory 需要转置,此时不能直接使用FETCH_FLOAT4
接下来就是正式计算的代码了
// 作为预取数据的数组 index
int write_stage_idx = 1;
int tile_idx = 0;
do{
// 大迭代中每次的 stride 为 BLOCK_SIZE_K
tile_idx += BLOCK_SIZE_K;
// 先将下一次大迭代的 global memory 中的数据考到 register 中
if(tile_idx< K){
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[OFFSET(
A_TILE_ROW_START + i, // row
A_TILE_COL + tile_idx, // col
K )]);
}
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[OFFSET(
tile_idx + B_TILE_ROW_START + i, // row
B_TILE_COL, // col
N )]);
}
}
// 作为加载数据的数组的 index write 为 0 load 就为 1
int load_stage_idx = write_stage_idx ^ 1;
#pragma unroll
for(int j=0; j<BLOCK_SIZE_K - 1; ++j){
// 第一次大迭代中 load_stage_idx 就为 0
// 因为在上部分代码中,我们已经进行了数据的预取
// 这里把下一次小迭代的需要数据放入 register 中
FETCH_FLOAT4(frag_a[(j+1)%2][0]) = FETCH_FLOAT4(As[load_stage_idx][(j+1)][a_tile_index]);
FETCH_FLOAT4(frag_a[(j+1)%2][4]) = FETCH_FLOAT4(As[load_stage_idx][(j+1)][a_tile_index + 64]);
FETCH_FLOAT4(frag_b[(j+1)%2][0]) = FETCH_FLOAT4(Bs[load_stage_idx][(j+1)][b_tile_index]);
FETCH_FLOAT4(frag_b[(j+1)%2][4]) = FETCH_FLOAT4(Bs[load_stage_idx][(j+1)][b_tile_index + 64]);
// 这里把第一次小迭代的数据拿出来进行计算
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[j%2][thread_y] * frag_b[j%2][thread_x];
}
}
}
if(tile_idx < K){
// 将刚才的预取到 register 中的数据,存到 shared momery 中
// 供下一次小迭代反复使用
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
As[write_stage_idx][A_TILE_COL][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index];
As[write_stage_idx][A_TILE_COL+1][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+1];
As[write_stage_idx][A_TILE_COL+2][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+2];
As[write_stage_idx][A_TILE_COL+3][A_TILE_ROW_START + i]=ldg_a_reg[ldg_index+3];
}
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(Bs[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}
__syncthreads();
// 下一次的预取索引就应该是上一次的加载索引
// 同理上面使用的预取索引就是下一次的加载索引 预取就是为了未来的加载嘛
write_stage_idx ^= 1;
}
// 把第一次小迭代的数据预取至寄存器
FETCH_FLOAT4(frag_a[0][0]) = FETCH_FLOAT4(As[load_stage_idx^1][0][a_tile_index]);
FETCH_FLOAT4(frag_a[0][4]) = FETCH_FLOAT4(As[load_stage_idx^1][0][a_tile_index + 64]);
FETCH_FLOAT4(frag_b[0][0]) = FETCH_FLOAT4(Bs[load_stage_idx^1][0][b_tile_index]);
FETCH_FLOAT4(frag_b[0][4]) = FETCH_FLOAT4(Bs[load_stage_idx^1][0][b_tile_index + 64]);
// 计算最后一次小迭代
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x];
}
}
}while(tile_idx< K);
每次大迭代中最后一次小迭代放到搬完数据的最后来做,猜测这样在执行指令的时候就可以边取存下一次迭代的数据时,边执行计算。最后把每个线程寄存器上的结果搬回 C 中
const int c_block_row = a_tile_index;
const int c_block_col = b_tile_index;
//store C00 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i][0]);
}
//store C01 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i][4]);
}
//store C10 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col,
N)]) = FETCH_FLOAT4(accum[i+4][0]);
}
//store C11 block
for(int i=0; i<4; i++){
FETCH_FLOAT4(C[OFFSET(
BLOCK_SIZE_M * by + c_block_row + 64 + i,
BLOCK_SIZE_N * bx + c_block_col + 64,
N)]) = FETCH_FLOAT4(accum[i+4][4]);
}
}
Trick
- 首先分块:如何分给 block,block 内如何分给线程
- 然后搬运数据:大迭代的数据怎么搬运、小迭代的数据怎么搬运、利用 shared memory 和 register 提高搬运数据的效率;利用预取掩盖访问 latency
- 最后如何防止 bank conflict,对各个 wrap 应该访问的地址进行了设计
To Be Continued