mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
CUDA: add head size of 64 to new mma
Haven't turned it on yet, but observe slightly better PP and slightly worse TG performance with that.
This commit is contained in:
@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
|
||||
// Perhaps the 256 head size needs a closer look
|
||||
// to see if this implementation is better.
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 64, 64> {
|
||||
// static constexpr int nbatch_fa = 64;
|
||||
// static constexpr int nwarps_max = 4;
|
||||
// static constexpr bool Q_in_reg = true;
|
||||
// static constexpr int nstages_target = 2;
|
||||
//
|
||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//};
|
||||
template <>
|
||||
struct fattn_mma_f16_config< 64, 64> {
|
||||
static constexpr int nbatch_fa = 64;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 80, 80> {
|
||||
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||
|
||||
if constexpr (ntiles == 1) {
|
||||
if constexpr (ncols2 > 1 || mask_h2) {
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
||||
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
||||
@@ -1905,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
@@ -1930,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
@@ -1940,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user