CUDA优化之LayerNorm性能优化实践
撰文 | 郭冉、姚迟、郑泽康、柳俊丞
以 PyTorch 为例,LayerNorm 的接口为:
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
其中 input 形状为:[∗, normalized_shape[0], normalized_shape[1], …,normalized_shape[−1]]
LayerNorm 中求方差的方法
1.two-pass方法
使用的公式是:
这种方法是一种 single pass 方法,在计算方差时只需要遍历一遍数据累加 x 的平方及累加 x,最后按上述公式计算得到方差。这种方法只需要遍历一遍数据,相比 two-pass 的算法,更容易达到好的性能,但是上面的 Wiki 参考链接中介绍由于 SumSquare 和 (Sum×Sum)/n 可能非常接近,可能会导致计算结果损失精度较大,因此这种方法不建议在实践中使用。
3.Welford 算法
使用的公式是:
Welford 算法也是一种 single pass 方法,且数值稳定性很好,因此现在很多框架都采用这种方法。本文的代码中采用的也是 Welford 方法。
OneFlow 深度优化 LayerNorm CUDA Kernel 的技巧
和 Softmax 一样,LayerNorm 也采用分段函数优化,对于不同的 num_cols 范围,采用不同的实现,以在各种情况下都能达到较高的有效带宽。
在每种实现中都采用了一个公共的优化:向量化访存,NVIDIA 性能优化的博客 Increase Performance with Vectorized Memory Access 中提到可以通过向量化内存操作来提高 CUDA Kernel 性能,很多 CUDA Kernel 都是带宽受限的,使用向量化内存操作可以减少总的指令数,减少延迟,提高带宽利用率。
理论上来说,在计算 LayerNorm 的过程中,输入 x 需要被读两次,第一次用于计算均值和方差。第二次用于得到均值和方差后的计算过程。而对 Global Memory 的访问操作是昂贵的,如果能将输入 x 先存起来,不重复读,就可以提升性能。在 GPU 中将输入 x 存起来可以使用寄存器或 Shared memory,但是寄存器资源和 Shared memory 资源都是有限的,如果 num_cols 过大,就会超出资源的使用限制,因此我们针对不同 num_cols 采用不同的实现,下面分别进行介绍:
1.num_cols <= 1024 的情况
针对 num_cols <= 1024 的情况,以 Warp 为单位处理一行或两行,将输入 x 存储到寄存器中。
WelfordWarpAllReduce 由 WelfordWarpReduce 和 Broadcast 操作完成,WelfordWarpReduce 借助 Warp 级别同步原语 __shfl_down_sync 实现,Broadcast操作借助 __shfl_sync 实现,代码如下:
template
T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
T b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
templateT, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}
在这里有个模板参数 thread_group_width,当 num_cols > pack_size * WarpSize 时,thread_group_width 为 WarpSize。当 num_cols 太小,即 num_cols
将 pack_size 个元素 pack 成更大的数据类型读入,但是 x 还要参与计算。因此我们定义一个union 结构的 Pack 类型,storage 用于从 Global Memory中读写,做计算时用 elem[i] 取每个元素参与计算,Pack 类型定义如下:
template<typename T, int N>
union Pack {
PackTypestorage;
T elem[N];
};
LayerNormWarpImpl Kernel 代码如下:
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
static_assert(cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int64_t lane_id = threadIdx.x;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows;
row += num_global_thread_group * rows_per_access) {
ComputeType thread_mean[rows_per_access];
ComputeType thread_m2[rows_per_access];
ComputeType thread_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_mean[row_id] = 0;
thread_m2[row_id] = 0;
thread_count[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load.template load(row_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id, thread_m2 + row_id,
thread_count + row_id);
}
} else {
#pragma unroll
for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = 0; }
}
}
}
ComputeType warp_mean[rows_per_access];
ComputeType warp_m2[rows_per_access];
ComputeType warp_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
int global_row_id = row + row_id;
ComputeType* row_buf = buf[row_id];
WelfordWarpAllReduce(
thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id,
warp_m2 + row_id, warp_count + row_id);
ComputeType row_mean = warp_mean[row_id];
ComputeType row_variance =
max(Div(warp_m2[row_id], warp_count[row_id]), static_cast(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon));
if (lane_id == 0) {
mean[global_row_id] = row_mean;
inv_variance[global_row_id] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store(row_buf + i * pack_size, global_row_id, col);
}
}
}
}
}
LOAD、STORE 分别代表输入输出,使用load.template load (ptr, row_id, col_id); 和 store.template store(ptr, row_id, col_id); 进行读取和写入。使用 LOAD 和 STORE 有两个好处:a) 可以在 CUDA Kernel中只关心计算类型 ComputeType,而不用关心具体的数据类型 T。b) 只需要加几行代码就可以快速支持 LayerNorm 和其他 Kernel Fuse,减少带宽需求,提升整体性能。ComputeType 代表计算类型。pack_size 代表向量化访存操作的 pack 元素的个数,我们将几个元素 pack 起来读写,提升带宽利用率。 cols_per_thread 代表每个线程处理的元素个数。 thread_group_width 代表处理元素的线程组的宽度,当 cols > pack_size * warp_size 时,thread_group_width 就是warp_size,即32。当 cols < pack_size * warp_size 时,就根据 cols 大小用 1/2个warp 或 1/4个warp 来处理每行的元素。采用更小的 thread_group_width 后,WarpAllReduce需要执行的轮次也相应减少。 rows_per_access 代表每个 thread_group 一次处理的行数,当 cols 较小且 thread_group_width 小于warp_size时,若 rows 能被2整除,我们就让每个线程处理2行来增加指令并行度,从而提升性能。 padding 代表当前是否做了 padding,若 cols 不是 warp_size 的整数倍,我们会把它padding 到最近的整数倍处理。
2.num_cols > 1024 的情况
LayerNormBlockSMemImpl Kernel的代码如下:
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (buf[i * num_packs + pack_id] - row_mean) * row_inv_var;
}
store.template store(pack, row, pack_id * pack_size);
}
}
}
3.num_cols 较大时,不使用 Shared Memory 的情况
当 num_cols 较大,当前硬件资源条件下使用Shared Memory的方法无法成功Launch Kernel时,使用这种实现:一个 Block 处理一行的元素,不使用 Shared Memory,重复读输入 x。
这种方法和前面第二种情况线程和元素对应关系一致,唯一的区别在于,第二种方法将输入 x 存储到Shared Memory 中,本方法不存储 x,在每次计算时需要再从 Global Memory 中读入 x。这种方法虽然需要多读一份 x,但是在实际执行时,部分输入可以被 Cache 缓存起来,不会实际增加很多时间。值得注意的是,在这种实现中,block_size 越大,SM 中能同时并行执行的 block 数就越少,对 Cache 的需求就越少,就有更多机会命中 Cache,因此我们使用较大的 block_size。
LayerNormBlockUncachedImpl 代码如下:
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon,
ComputeType* mean, ComputeType* inv_variance) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
const int pack_offset = pack_id * pack_size;
load.template load(pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) { pack[i] = (pack[i] - row_mean) * row_inv_var; }
store.template store(pack, row, pack_offset);
}
}
}
3 OneFlow Softmax 库
oneflow::cuda::softmax::DirectLoad float> load(in, cols);
oneflow::cuda::softmax::DirectStore<float, half> store(out, cols);
oneflow::cuda::softmax::DispatchSoftmax<decltype(load), decltype(store), float>(
cuda_stream, load, store, rows, cols);
性能优势,可见之前的文章分享。此外,最近一年进一步优化了小的 num_cols 下的性能。
同时支持了 Softmax 和 LogSoftmax,适用场景更广。
输入输出通过 Load/Store 结构传递,解耦数据IO和计算,只需要加几行代码就可以快速支持 Softmax 和其他 Kernel Fuse,减少带宽需求,带来很高的性能收益。