This commit is contained in:
Iwan Kawrakow
2025-12-28 08:57:24 +00:00
parent 9b7d08eaa2
commit ba0e88a5e3
2 changed files with 3 additions and 3 deletions

View File

@@ -1102,8 +1102,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
//const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
//const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);

View File

@@ -1407,7 +1407,7 @@ static ggml_tensor * llm_build_kqv(
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);