CUDA FA WIP - TG, not working yet.

This commit is contained in:
Iwan Kawrakow
2025-03-03 22:23:59 +02:00
parent 47474c1c7e
commit f064db93b2
7 changed files with 161 additions and 103 deletions

View File

@@ -1,9 +1,9 @@
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
__launch_bounds__(Dk, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ Q,
@@ -43,14 +43,15 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne3) {
#ifdef FP16_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_softcap && !(D == 128 || D == 256)) {
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<Dk>(type_K);
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
@@ -67,12 +68,13 @@ static __global__ void flash_attn_vec_ext_f16(
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = Dk / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
__builtin_assume(tid < D);
__builtin_assume(tid < Dk);
__shared__ half KQ[ncols*D];
__shared__ half KQ[ncols*Dk];
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
@@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
half2 Q_h2[ncols][Dk/(2*WARP_SIZE)];
int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)];
half2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1];
if (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
@@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f16(
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
int * tmp_q_i32 = (int *) &KQ[j*Dk];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
tmp_q_i32[i] = 0;
}
if (threadIdx.x < D/QK8_1) {
if (threadIdx.x < Dk/QK8_1) {
tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
}
continue;
@@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f16(
const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
}
}
@@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
int * tmp_q_i32 = (int *) &KQ[j*Dk];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
@@ -154,7 +156,7 @@ static __global__ void flash_attn_vec_ext_f16(
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
@@ -166,13 +168,13 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ[j*D + tid] = -HALF_MAX_HALF;
KQ[j*Dk + tid] = -HALF_MAX_HALF;
}
half2 VKQ[ncols] = {{0.0f, 0.0f}};
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -186,10 +188,10 @@ static __global__ void flash_attn_vec_ext_f16(
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) {
break;
}
@@ -209,7 +211,7 @@ static __global__ void flash_attn_vec_ext_f16(
}
if (threadIdx.x == 0) {
KQ[j*D + i_KQ] = sum;
KQ[j*Dk + i_KQ] = sum;
}
}
}
@@ -234,9 +236,9 @@ static __global__ void flash_attn_vec_ext_f16(
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
const half val = hexp(KQ[j*Dk + tid] - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale + val;
KQ[j*D + tid] = val;
KQ[j*Dk + tid] = val;
VKQ[j] *= __half2half2(KQ_max_scale);
}
@@ -244,8 +246,8 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < D; k0 += 2) {
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
for (int k0 = 0; k0 < Dv; k0 += 2) {
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) {
break;
}
@@ -254,7 +256,7 @@ static __global__ void flash_attn_vec_ext_f16(
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];
}
}
@@ -285,27 +287,28 @@ static __global__ void flash_attn_vec_ext_f16(
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr int nwarps = Dk/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -347,9 +350,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -358,9 +361,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -368,15 +371,19 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
}
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_f16_case \
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
<D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
#define DECL_FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_f16_case \
<Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
@@ -435,3 +442,6 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

View File

@@ -1,9 +1,9 @@
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
__launch_bounds__(Dk, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
@@ -42,14 +42,15 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne2,
const int ne3) {
// Skip unused kernel variants for faster compilation:
if (use_softcap && !(D == 128 || D == 256)) {
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<Dk>(type_K);
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
@@ -64,15 +65,16 @@ static __global__ void flash_attn_vec_ext_f32(
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = Dk / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
__builtin_assume(tid < D);
__builtin_assume(tid < Dk);
__shared__ float KQ[ncols*D];
__shared__ float KQ[ncols*Dk];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ[j*D + tid] = -FLT_MAX/2.0f;
KQ[j*Dk + tid] = -FLT_MAX/2.0f;
}
float kqmax[ncols];
@@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
float2 Q_f2[ncols][D/(2*WARP_SIZE)];
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
float2 Q_f2[ncols][Dk/(2*WARP_SIZE)];
int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk >= Dk/(sizeof(int)*QK8_1)];
float2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1];
if (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
@@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f32(
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
int * tmp_q_i32 = (int *) &KQ[j*Dk];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
tmp_q_i32[i] = 0;
}
if (threadIdx.x < D/QK8_1) {
if (threadIdx.x < Dk/QK8_1) {
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
}
continue;
@@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f32(
const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
}
}
@@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
int * tmp_q_i32 = (int *) &KQ[j*Dk];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
@@ -153,7 +155,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
@@ -165,8 +167,8 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new_arr[ncols];
@@ -176,10 +178,10 @@ static __global__ void flash_attn_vec_ext_f32(
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) {
break;
}
@@ -195,7 +197,7 @@ static __global__ void flash_attn_vec_ext_f32(
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
if (threadIdx.x == 0) {
KQ[j*D + i_KQ] = sum;
KQ[j*Dk + i_KQ] = sum;
}
}
}
@@ -220,9 +222,9 @@ static __global__ void flash_attn_vec_ext_f32(
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const float val = expf(KQ[j*D + tid] - kqmax[j]);
const float val = expf(KQ[j*Dk + tid] - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale + val;
KQ[j*D + tid] = val;
KQ[j*Dk + tid] = val;
VKQ[j] *= KQ_max_scale;
}
@@ -230,15 +232,15 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
#pragma unroll
for (int k = 0; k < D; ++k) {
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
for (int k = 0; k < Dv; ++k) {
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) {
break;
}
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*D + k];
VKQ[j] += V_ki*KQ[j*Dk + k];
}
}
@@ -269,24 +271,25 @@ static __global__ void flash_attn_vec_ext_f32(
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
}
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr int nwarps = Dk/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -303,9 +306,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -314,9 +317,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
return;
}
@@ -346,15 +349,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (softcap == 0.0f) {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
}
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_f32_case \
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
<D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
#define DECL_FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \
template void ggml_cuda_flash_attn_ext_vec_f32_case \
<Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
@@ -406,3 +413,6 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

View File

@@ -175,7 +175,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
ggml_cuda_flash_attn_ext_vec_f16_case<D, D, type_K, type_V>(ctx, dst); \
return; \
} \
@@ -247,6 +247,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -260,6 +263,10 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
@@ -267,7 +274,13 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
ggml_cuda_flash_attn_ext_vec_f32_case<D, D, type_K, type_V>(ctx, dst); \
return; \
} \
#define FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \
if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f32_case<Dk, Dv, type_K, type_V>(ctx, dst); \
return; \
} \
@@ -327,6 +340,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -335,6 +351,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
@@ -343,7 +362,6 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh"
DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f16.cuh"
DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh"
DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec-f32.cuh"
DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);