mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-07 15:00:11 +00:00
DeepSeek CUDA Flash Attention (#241)
* WIP CUDA FA with Dk != Dv * WIP * CUDA FA WIP - It actually works! No TG yet, but for PP I can run FA with fp16 cache and it gets the same answer. * CUDA FA WIP - it now works for Q8_0 + Q8_0 for KV cache * CUDA FA WIP - TG, not working yet. * CUDA FA with Dk != Dv: it works now for DeepSeek --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3275,6 +3275,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
if (op->src[0]->ne[0] == 128) {
|
||||
return true;
|
||||
}
|
||||
if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) {
|
||||
return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
|
||||
(op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
|
||||
}
|
||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)(
|
||||
typedef float (*vec_dot_KQ_f32_t)(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -92,7 +92,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -142,7 +142,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4)
|
||||
return *((const int *) &val0_8);
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -152,7 +152,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -189,7 +189,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -226,7 +226,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -277,7 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
template<typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -287,7 +287,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
@@ -320,7 +320,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
template <typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -330,7 +330,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_0;
|
||||
@@ -353,7 +353,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
template <typename T, int Dk>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
||||
@@ -368,7 +368,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
half2 sum2 = make_half2(0.0f, 0.0f);
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
@@ -384,7 +384,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
@@ -603,29 +603,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
|
||||
return x[i];
|
||||
}
|
||||
|
||||
template <int D>
|
||||
template <int Dk>
|
||||
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
|
||||
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, D> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
|
||||
type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, D> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, Dk> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, Dk> :
|
||||
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<half, Dk> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, Dk> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, Dk> :
|
||||
type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<half, Dk> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, Dk> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, Dk> :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template <int D>
|
||||
template <int Dk>
|
||||
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
|
||||
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, D> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
|
||||
type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, D> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
|
||||
return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, Dk> :
|
||||
type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, Dk> :
|
||||
type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl<float, Dk> :
|
||||
type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, Dk> :
|
||||
type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, Dk> :
|
||||
type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0<float, Dk> :
|
||||
type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, Dk> :
|
||||
type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, Dk> :
|
||||
nullptr;
|
||||
}
|
||||
|
||||
@@ -653,20 +653,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
template<int Dv, int parallel_blocks> // Dv == V head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
__launch_bounds__(Dv, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
||||
dst += D * gridDim.y*blockIdx.x;
|
||||
VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
||||
dst += Dv * gridDim.y*blockIdx.x;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
__builtin_assume(tid < Dv);
|
||||
|
||||
__shared__ float2 meta[parallel_blocks];
|
||||
if (tid < 2*parallel_blocks) {
|
||||
@@ -690,20 +690,20 @@ static __global__ void flash_attn_combine_results(
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
static void on_no_fattn_vec_case(const int D) {
|
||||
if (D == 64) {
|
||||
static void on_no_fattn_vec_case(const int Dk, const int Dv) {
|
||||
if (Dk == 64 && Dv == 64) {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
|
||||
fprintf(stderr, "By default only f16 KV cache is supported.\n");
|
||||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
} else if (D == 128) {
|
||||
} else if (Dk == 128 && Dv == 128) {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
|
||||
fprintf(stderr, "Supported combinations:\n");
|
||||
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
|
||||
@@ -715,14 +715,22 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
|
||||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
else if (Dk == 192 && Dv == 128) {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n");
|
||||
// TODO: add what is supported
|
||||
}
|
||||
else if (Dk == 576 && Dv == 512) {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n");
|
||||
// TODO: add what is supported
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
|
||||
fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv);
|
||||
fprintf(stderr, "Only f16 is supported.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int parallel_blocks>
|
||||
template <int Dk, int Dv, int parallel_blocks>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
|
||||
@@ -838,11 +846,11 @@ void launch_fattn(
|
||||
return;
|
||||
}
|
||||
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 block_dim_combine(Dv, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const int shmem_combine = 0;
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
flash_attn_combine_results<Dv, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
@@ -291,13 +291,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
template<int Dk, int Dv, int ncols, int nwarps, int parallel_blocks, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -44,8 +44,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
if (use_softcap && !(Dk == 128 || Dk == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -61,15 +62,22 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
const int stride_K2 = nb11 / sizeof(half2);
|
||||
const int stride_V2 = nb12 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
// TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ?
|
||||
// let's assume it is is both.
|
||||
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
|
||||
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv
|
||||
|
||||
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
|
||||
|
||||
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
|
||||
// This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension
|
||||
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts.
|
||||
float2 * KV_tmp2 = (float2 *) KV_tmp;
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
@@ -79,16 +87,16 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
|
||||
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
__shared__ float Q_f[ncols][D];
|
||||
__shared__ float Q_f[ncols][Dk];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) {
|
||||
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
||||
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
||||
@@ -112,8 +120,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
|
||||
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) {
|
||||
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x];
|
||||
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
|
||||
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
|
||||
}
|
||||
@@ -124,7 +132,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
|
||||
for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) {
|
||||
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
|
||||
float Q_k[ncols/nwarps];
|
||||
|
||||
@@ -193,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
|
||||
}
|
||||
@@ -206,11 +214,11 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int k = k0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||
KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
|
||||
KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,14 +226,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
|
||||
float2 V_k[(D/2)/WARP_SIZE];
|
||||
float2 V_k[(Dv/2)/WARP_SIZE];
|
||||
float KQ_k[ncols/nwarps];
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
|
||||
V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
@@ -235,7 +243,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
|
||||
@@ -259,7 +267,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
@@ -268,8 +276,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||
@@ -285,14 +293,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, D, cols_per_block, nwarps, parallel_blocks, use_softcap>;
|
||||
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
__launch_bounds__(Dk, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -43,14 +43,15 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
|
||||
if (use_softcap && !(Dk == 128 || Dk == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<Dk>(type_K);
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
|
||||
|
||||
@@ -67,12 +68,13 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
|
||||
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = Dk / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
__builtin_assume(tid < Dk);
|
||||
|
||||
__shared__ half KQ[ncols*D];
|
||||
__shared__ half KQ[ncols*Dk];
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
half kqmax[ncols];
|
||||
@@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
|
||||
half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
half2 Q_h2[ncols][Dk/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)];
|
||||
half2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
@@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
int * tmp_q_i32 = (int *) &KQ[j*Dk];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
if (threadIdx.x < Dk/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
@@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
@@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
int * tmp_q_i32 = (int *) &KQ[j*Dk];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
@@ -154,7 +156,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
@@ -166,13 +168,13 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||
KQ[j*Dk + tid] = -HALF_MAX_HALF;
|
||||
}
|
||||
|
||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
||||
@@ -186,10 +188,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -209,7 +211,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
KQ[j*Dk + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -234,9 +236,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
|
||||
const half val = hexp(KQ[j*Dk + tid] - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
||||
KQ[j*D + tid] = val;
|
||||
KQ[j*Dk + tid] = val;
|
||||
|
||||
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||
}
|
||||
@@ -244,8 +246,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D; k0 += 2) {
|
||||
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
for (int k0 = 0; k0 < Dv; k0 += 2) {
|
||||
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -254,7 +256,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
||||
VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,27 +287,28 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
constexpr int nwarps = Dk/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
|
||||
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
|
||||
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
@@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -347,9 +350,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -358,9 +361,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -368,15 +371,19 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f16_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
<D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define DECL_FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f16_case \
|
||||
<Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -435,3 +442,6 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
template<int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
__launch_bounds__(Dk, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ Q,
|
||||
@@ -42,14 +42,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
|
||||
if (use_softcap && !(Dk == 128 || Dk == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<Dk>(type_K);
|
||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
||||
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
|
||||
|
||||
@@ -64,15 +65,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
|
||||
static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = Dk / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
__builtin_assume(tid < Dk);
|
||||
|
||||
__shared__ float KQ[ncols*D];
|
||||
__shared__ float KQ[ncols*Dk];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -FLT_MAX/2.0f;
|
||||
KQ[j*Dk + tid] = -FLT_MAX/2.0f;
|
||||
}
|
||||
|
||||
float kqmax[ncols];
|
||||
@@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
float2 Q_f2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
|
||||
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
float2 Q_f2[ncols][Dk/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk >= Dk/(sizeof(int)*QK8_1)];
|
||||
float2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
@@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
int * tmp_q_i32 = (int *) &KQ[j*Dk];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
if (threadIdx.x < Dk/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
@@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
@@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
int * tmp_q_i32 = (int *) &KQ[j*Dk];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
@@ -153,7 +155,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
@@ -165,8 +167,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
float VKQ[ncols] = {0.0f};
|
||||
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||
const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
|
||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new_arr[ncols];
|
||||
@@ -176,10 +178,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -195,7 +197,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
KQ[j*Dk + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -220,9 +222,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const float val = expf(KQ[j*D + tid] - kqmax[j]);
|
||||
const float val = expf(KQ[j*Dk + tid] - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
||||
KQ[j*D + tid] = val;
|
||||
KQ[j*Dk + tid] = val;
|
||||
|
||||
VKQ[j] *= KQ_max_scale;
|
||||
}
|
||||
@@ -230,15 +232,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int k = 0; k < D; ++k) {
|
||||
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
|
||||
for (int k = 0; k < Dv; ++k) {
|
||||
if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_ki*KQ[j*D + k];
|
||||
VKQ[j] += V_ki*KQ[j*Dk + k];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,24 +271,25 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
constexpr int nwarps = Dk/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
|
||||
constexpr bool need_f16_K = Dk != 128 && Dk != 192;
|
||||
constexpr bool need_f16_V = Dv != 128 && Dv != 64;
|
||||
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
@@ -303,9 +306,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -314,9 +317,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -346,15 +349,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
if (softcap == 0.0f) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f32_case \
|
||||
<D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
<D, D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define DECL_FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \
|
||||
template void ggml_cuda_flash_attn_ext_vec_f32_case \
|
||||
<Dk, Dv, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
@@ -406,3 +413,6 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
#include <mma.h>
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
|
||||
// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int Dk, int Dv, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@@ -47,8 +47,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_softcap && !(D == 128 || D == 256)) {
|
||||
if (use_softcap && !(Dk == 128 || Dk == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -58,11 +59,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0.");
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
||||
@@ -74,30 +75,32 @@ static __global__ void flash_attn_ext_f16(
|
||||
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
||||
|
||||
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
||||
constexpr int D_padded = D + 8;
|
||||
constexpr int Dk_padded = Dk + 8;
|
||||
constexpr int Dv_padded = Dv + 8;
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_K = nb11 / sizeof(half);
|
||||
const int stride_V = nb21 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
const half2 softcap_2 = make_half2(softcap, softcap);
|
||||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
frag_b Q_b[Dk/16][ncols/frag_n];
|
||||
|
||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
||||
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
||||
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded;
|
||||
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
||||
float * KQ_f = (float *) KQ;
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
@@ -120,18 +123,18 @@ static __global__ void flash_attn_ext_f16(
|
||||
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
||||
}
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
__shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
|
||||
break;
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||
VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,12 +143,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
if (i0 + WARP_SIZE > Dk && i >= Dk) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,10 +156,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// Load Q into tensor core fragments/registers since it will be used frequently:
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
for (int i0 = 0; i0 < Dk; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,9 +176,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
@@ -309,9 +312,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
||||
frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||
@@ -322,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
@@ -332,15 +335,15 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
||||
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded);
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, nvcuda::wmma::mem_col_major);
|
||||
Dk_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,18 +361,18 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VKQ_ratio; ++l) {
|
||||
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
||||
VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i];
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
||||
VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,16 +395,16 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
if (i0 + WARP_SIZE > Dv && i >= Dv) {
|
||||
break;
|
||||
}
|
||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||
float dst_val = VKQ[j_VKQ*Dv_padded + i];
|
||||
if (parallel_blocks == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
||||
dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
||||
@@ -446,13 +449,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||
|
||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||
template <int Dk, int Dv, int cols_per_block, typename KQ_acc_t>
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
||||
constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16;
|
||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
@@ -462,29 +465,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
|
||||
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
|
||||
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
}
|
||||
|
||||
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
|
||||
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
|
||||
<D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
<D, D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t) \
|
||||
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
|
||||
<Dk, Dv, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
|
||||
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
|
||||
@@ -518,3 +525,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
|
||||
|
||||
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);
|
||||
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);
|
||||
|
||||
@@ -12,6 +12,14 @@
|
||||
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
if (Q->ne[0] != V->ne[0]) {
|
||||
if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) {
|
||||
fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
|
||||
@@ -20,22 +28,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
|
||||
@@ -46,19 +57,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
|
||||
break;
|
||||
// case 256:
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
||||
@@ -76,16 +90,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
constexpr int cols_per_block = 8;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
|
||||
@@ -99,22 +116,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
|
||||
@@ -127,22 +147,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
constexpr int cols_per_block = 32;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
|
||||
@@ -152,7 +175,13 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||
}
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \
|
||||
if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<Dk, Dv, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
@@ -218,6 +247,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
|
||||
|
||||
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#else
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
|
||||
@@ -231,14 +263,24 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
|
||||
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
|
||||
|
||||
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case<D, D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \
|
||||
if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case<Dk, Dv, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
|
||||
@@ -298,6 +340,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#else
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||
|
||||
@@ -306,9 +351,12 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
|
||||
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||
FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f16.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec-f32.cuh"
|
||||
|
||||
DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 16, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 16, float);
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float);
|
||||
|
||||
@@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(96, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 32, float);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 32, float);
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float);
|
||||
|
||||
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 16, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 16, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 16, half);
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);
|
||||
|
||||
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(112, 32, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 32, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 32, half);
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);
|
||||
|
||||
@@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(96, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(128, 8, half);
|
||||
DECL_FATTN_WMMA_F16_CASE(256, 8, half);
|
||||
|
||||
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half);
|
||||
|
||||
@@ -8777,7 +8777,11 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
|
||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
||||
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
|
||||
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
|
||||
// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
|
||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
|
||||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) {
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
//ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
|
||||
Reference in New Issue
Block a user