mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
CUDA: Fix FA for Pascal GPU (#1036)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -52,9 +52,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(Dk == 128 || Dk == 256)) {
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -70,22 +70,15 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_K2 = nb11 / sizeof(half2);
|
||||
const int stride_V2 = nb12 / sizeof(half2);
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
// TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ?
|
||||
// let's assume it is is both.
|
||||
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
|
||||
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv
|
||||
static_assert(D % (2 * WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
|
||||
|
||||
// This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension
|
||||
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts.
|
||||
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
|
||||
|
||||
float2 * KV_tmp2 = (float2 *) KV_tmp;
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
@@ -95,16 +88,16 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
|
||||
float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
__shared__ float Q_f[ncols][Dk];
|
||||
__shared__ float Q_f[ncols][D];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
||||
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
||||
@@ -128,8 +121,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) {
|
||||
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x];
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
|
||||
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
||||
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
|
||||
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
|
||||
}
|
||||
@@ -140,7 +133,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) {
|
||||
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
|
||||
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
|
||||
float Q_k[ncols/nwarps];
|
||||
|
||||
@@ -209,7 +202,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
|
||||
}
|
||||
@@ -222,11 +215,11 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int k = k0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
|
||||
KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
|
||||
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
|
||||
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)* stride_KV2 + i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,14 +227,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
|
||||
float2 V_k[(Dv/2)/WARP_SIZE];
|
||||
float2 V_k[(D/2)/WARP_SIZE];
|
||||
float KQ_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i];
|
||||
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
@@ -251,7 +244,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
|
||||
@@ -275,7 +268,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) {
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
@@ -284,8 +277,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
@@ -301,13 +294,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
|
||||
Reference in New Issue
Block a user