CUDA wmma API中使用wmma::load

mac2026-05-22  5

由于这个点在写kernel时容易被忽略,所以写下来加强一下记忆。

使用tensor core进行计算时, 由于需要使用wmma::load_matrix_sync()从内存中将数据copy到fragement(寄存器)上,而wmma为warp-level操作所以这里会涉及bank conflict的问题,既不同threads同时访问同一bank的不同地址时内存不能及时为所有thread提供数据。

需要注意的是,由于warp与warp之间已经通过硬件协调,所以其即使访问同一bank的不同地址也不用开发者关心bank confilct(毕竟warp之间只是逻辑上并行)。

一、申请shared memory空间时的地址分布

若以声明一个二维数组的形式申请一块shared memory,连续的地址会被分布到不同bank上。例如

__shared__ half cache[][num_bank];

那么half[0][0]half[0][1], half[0][2]half[0][3], ..., half[0][30]half[0][31]将被优先存储在不同的bank上(官方文档:连续4bytes会被分配到一个bank),这样也一定程度上避免了不同线程访问同一bank,直到所有bank上都分配了数据再从bank0开始分配。

二、减少bank confilct的发生

虽然我们无从知道wmma::load_matrix_sync()是如何将复制任务分配给32个thread的,但是通过减少数据存储的规律性,例如尽量不如不同行的同一位置存在同一个bank(其实就是避免不同行或列存在同一个bank),就能减少bank conflict的可能性,因为warp从内存中读取数据无非就是各thread同时读取行或列中各个数据。

所以在申请shared memory时如果读取的行/列刚好为32*4 bytes(每个bank存4 bytes)的倍数,就需要注意在接着存放下一行/列数据时不要再从bank0开始存了。所以方法就是先shift一段距离再开始存放。例如

__shared__ half cache[][num_bank+8];

为什么这里是8呢,不是偏移4个字节就行了吗?

因为对于wmma::load_matrix_sync(),要求fragment的leading dimension必须满足16字节对齐,而8个half刚好16字节。

最新回复(0)