f32 vec kernel

Iwan Kawrakow
2025-09-03 14:52:32 +03:00
parent 2a09cd1c08
commit bf0b5088e0
2 changed files with 117 additions and 94 deletions


@@ -6,55 +6,54 @@
//
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-vec-common.cuh"
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// Currently llvm with the amdgcn target does not support unrolling loops
// that contain a break that cannot be resolved at compile time.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K-, V-head size
#ifndef GGML_USE_HIP
__launch_bounds__(Dk, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int2 * __restrict__ KV_min_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const float softcap,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
// Skip unused kernel variants for faster compilation:
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
}
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
if (ncols > 1) {
NO_DEVICE_CODE;
return;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
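The guards above rely on ggml's NO_DEVICE_CODE macro to compile excluded template variants down to stubs. A minimal standalone sketch of the pattern, with our own stand-in macro rather than ggml's exact definition:

```cuda
#include <cstdio>

// Stand-in for ggml's NO_DEVICE_CODE (assumption: the real macro also
// traps/reports if an excluded variant is ever launched).
#define NO_DEVICE_CODE do { __trap(); } while (0)

template <int D, bool use_logit_softcap>
__global__ void kernel_variant() {
    // The condition depends only on template parameters, so it folds at
    // compile time and the real body is eliminated for excluded variants,
    // keeping build time and binary size down.
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }
    printf("variant D=%d, softcap=%d\n", D, (int) use_logit_softcap);
}

int main() {
    kernel_variant<128, true><<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```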
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
@@ -62,17 +61,19 @@ static __global__ void flash_attn_vec_ext_f32(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
Q += nb03*sequence + nb02* head + nb01*ic0;
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
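For intuition, `head / gqa_ratio` collapses groups of Q heads onto a single K/V head. A tiny standalone sketch with hypothetical head counts:

```cuda
#include <cstdio>

// Hypothetical illustration of the GQA mapping used above: with more Q
// heads (ne02) than K/V heads (ne12), consecutive groups of Q heads all
// read the same K/V head via head / gqa_ratio.
int main() {
    const int ne02 = 32;               // Q heads (example value)
    const int ne12 = 8;                // K/V heads (example value)
    const int gqa_ratio = ne02 / ne12; // here: 4 Q heads per K/V head
    for (int head = 0; head < ne02; ++head) {
        printf("Q head %2d -> K/V head %d\n", head, head / gqa_ratio);
    }
    return 0;
}
```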
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
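A hedged sketch of the per-head ALiBi slope helper, assuming it matches upstream ggml's common.cuh (the m0/m1/n_head_log2 parameters are the ones passed in above):

```cuda
#include <cmath>
#include <cstdio>

// Sketch of get_alibi_slope: heads below the power-of-two cutoff use base
// m0 with exponent h+1, the remainder use base m1 with odd exponents.
__host__ __device__ static inline float alibi_slope_sketch(
        float max_bias, unsigned h, unsigned n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled -> mask is applied unscaled
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}

int main() {
    // Example: 8 heads, max_bias = 8; upstream computes n_head_log2 as the
    // largest power of two <= n_head, and m0/m1 from max_bias.
    const unsigned n_head_log2 = 8;
    const float max_bias = 8.0f;
    const float m0 = powf(2.0f, -max_bias / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    for (unsigned h = 0; h < 8; ++h) {
        printf("head %u: slope %.5f\n", h,
               alibi_slope_sketch(max_bias, h, n_head_log2, m0, m1));
    }
    return 0;
}
```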
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
}
float kqmax[ncols];
float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
kqsum[j] = 0.0f;
}
float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -102,6 +104,13 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
__shared__ float maskf_shared[ncols*Dk];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*Dk + tid] = 0.0f;
}
__syncthreads();
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -176,10 +185,26 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
K += (first_y + blockIdx.y*Dk) * nb11;
V += (first_y + blockIdx.y*Dv) * nb21;
maskh += (first_y + blockIdx.y*Dk);
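The rewritten loop below stripes the (possibly KV_min_max-clipped) KV window across blockIdx.y instead of the old parallel_blocks scheme. A small host-side illustration with hypothetical sizes:

```cuda
#include <cstdio>

// Hypothetical sizes, only to show which KV tiles each blockIdx.y value
// visits: block y starts at first_y + y*Dk and strides by gridDim.y*Dk.
int main() {
    const int Dk = 128, grid_y = 4;      // example head size / grid
    const int first_y = 0, k_max = 1024; // example KV window
    for (int y = 0; y < grid_y; ++y) {
        printf("block y=%d:", y);
        for (int k0 = first_y + y*Dk; k0 < k_max; k0 += grid_y*Dk) {
            printf(" [%d,%d)", k0, k0 + Dk);
        }
        printf("\n");
    }
    return 0;
}
```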
for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
// Increment pointers after each loop:
K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
maskf_shared[j*Dk + tid] = slope*__half2float(maskh[j*ne11 + tid]);
}
__syncthreads();
}
float kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -196,12 +221,14 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_softcap) {
sum = softcap*tanhf(sum);
if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
sum += maskf_shared[j*Dk + i_KQ];
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -215,7 +242,6 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_new_arr[j];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
@@ -246,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32(
break;
}
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
const float V_ki = dequantize_1_v(V + k*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*Dk + k];
@@ -256,8 +282,8 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf) {
const float sink = sinksf[blockIdx.y];
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -270,7 +296,7 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
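This rescale is the standard streaming-softmax update; in our notation (not from the diff), when the running row maximum grows from m to m':

```latex
% Streaming softmax: rescale running sum S and accumulator O when the
% row maximum grows from m to m' (c = KQ_max_scale in the kernel).
\begin{aligned}
c  &= e^{\,m - m'},\\
S' &= c\,S + \sum_{k \in \text{tile}} e^{\,s_k - m'},\\
O' &= c\,O + \sum_{k \in \text{tile}} e^{\,s_k - m'}\, v_k.
\end{aligned}
```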
@@ -309,26 +335,28 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
float dst_val = VKQ[j_VKQ];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
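When gridDim.y > 1, each block writes an unnormalized partial output plus its (kqmax, kqsum) pair to dst_meta. Assuming the combine step matches the usual flash-attention split-K reduction in launch_fattn, the final output over B partial results (m_b, s_b, o_b) is:

```latex
% Split-K combine over B partial results (m_b, s_b, o_b); this mirrors
% the standard flash-attention reduction, not code shown in this diff.
m^{*} = \max_b m_b, \qquad
O = \frac{\sum_b e^{\,m_b - m^{*}}\, o_b}{\sum_b e^{\,m_b - m^{*}}\, s_b}.
```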
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = Dk/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = Dk != 128;
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<Dv, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
}
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -341,59 +369,54 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float softcap;
memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
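logit_softcap applies a tanh cap c to the attention logits; assuming the host folds the 1/c factor into the Q scale as upstream llama.cpp does, the device code only has to multiply by c:

```latex
% Logit soft-capping with cap c = logit_softcap (hedged: the 1/c factor
% is folded into `scale` on the host, so the kernel computes c * tanh(s)).
s_{\mathrm{capped}} = c \cdot \tanh\!\left(\frac{s}{c}\right)
```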
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
constexpr int cols_per_block = 8;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}


@@ -73,9 +73,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
//if (precision == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
// ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
//} else {
// ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
//}
return;
}