Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-04-30 19:31:48 +00:00).

Commit: "Minor"
This commit is contained in:
@@ -1102,8 +1102,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #pragma unroll
     for (int col = 0; col < cols_per_thread; ++col) {
         static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
-        const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
-        //const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+        //const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
+        const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
         const float sink = sinks_f[jc % ncols2];

         const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -1407,7 +1407,7 @@ static ggml_tensor * llm_build_kqv(
     //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

     if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
-        model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) {
+        model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) {
         // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
         // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
Reference in New Issue
Block a user