Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
Mimo-V2-Flash support (#1096)
* Mimo-2 support
* Fix bug for head sizes not being the same
  It still does not solve the Mimo-2 quantized cache issue.
* Fix quantized cache
* Minor
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
@@ -1080,6 +1080,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        __syncthreads();
    }

    // Finally, sum up partial KQ rowsums.
    // The partial sums are spread across 8/4 threads each, does not need full reduce.
    {
        constexpr int offset_first = ntiles == 1 ? 16 : 2;
        constexpr int offset_last = ntiles == 1 ? 4 : 1;
#pragma unroll
        for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
            }
        }
    }

    // If attention sinks are used, potentially re-scale if KQ_max is small.
    // Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
    // so it's being done unconditionally for every thread.
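For readers following along: each KQ_rowsum[col] above is only a partial sum, spread over 8 lanes (ntiles == 1) or 4 lanes (ntiles == 2) of the warp, and the XOR-shuffle loop is a butterfly reduction over exactly that lane group rather than over all 32 lanes. Below is a self-contained toy kernel showing the same pattern for the 8-lane case (offsets 16, 8, 4 combine the lanes whose IDs differ only in bits 2-4); the names and launch shape are illustrative, not the actual flash-attention layout.

// partial_rowsum_demo.cu -- illustrative only, not part of the diff.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void partial_rowsum_demo(float * out) {
    const int lane = threadIdx.x & 31;
    float v = 1.0f;  // stand-in for one lane's partial KQ rowsum
#pragma unroll
    for (int offset = 16; offset >= 4; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset, 32);
    }
    out[lane] = v;   // every lane of an 8-lane group now holds the group total (8.0f here)
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, 32*sizeof(float));
    partial_rowsum_demo<<<1, 32>>>(d_out);
    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 0: %g, lane 31: %g\n", h_out[0], h_out[31]);  // prints 8 and 8
    cudaFree(d_out);
    return 0;
}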
@@ -1088,6 +1102,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
    for (int col = 0; col < cols_per_thread; ++col) {
        static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
        //const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col);
        const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
        const float sink = sinks_f[jc % ncols2];
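The hunk is cut off just after the sink logit is fetched, but the sink comments at the end of the previous hunk appear to describe the usual online-softmax update: if the sink logit exceeds the running row maximum, the existing rowsum (and the VKQ accumulator) is re-scaled before exp(sink - max) is added to the denominator. A host-side sketch of that kind of update follows; all names are hypothetical and the accumulator re-scaling is only noted in a comment.

// sink_update_sketch.cpp -- illustrative host-side sketch, not the kernel code.
#include <algorithm>
#include <cmath>
#include <cstdio>

// One row of the streaming-softmax state: running maximum and running sum of
// exp(logit - max). Adding a "sink" logit is just one more online update.
struct RowState { float kq_max; float kq_rowsum; };

static void add_sink(RowState & r, float sink) {
    const float new_max = std::max(r.kq_max, sink);
    const float scale   = std::exp(r.kq_max - new_max); // re-scale only if the sink raised the max
    r.kq_rowsum = r.kq_rowsum * scale + std::exp(sink - new_max);
    r.kq_max    = new_max;
    // In the real kernel the VKQ accumulators would be re-scaled by `scale` as well.
}

int main() {
    RowState r = {2.0f, 5.0f};   // made-up running state
    add_sink(r, 0.0f);           // sink below the max: only the rowsum grows
    printf("max=%g rowsum=%g\n", r.kq_max, r.kq_rowsum);
    return 0;
}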
@@ -1126,20 +1141,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
        }
    }

    // Finally, sum up partial KQ rowsums.
    // The partial sums are spread across 8/4 threads each, does not need full reduce.
    {
        constexpr int offset_first = ntiles == 1 ? 16 : 2;
        constexpr int offset_last = ntiles == 1 ? 4 : 1;
#pragma unroll
        for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
            }
        }
    }

    // Combine VKQ accumulator values if np > 1.
    // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
    // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1803,18 +1804,16 @@ static void launch_fattn_new_mma(
        to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
        K_data = (char *) K_f16.ptr;

        nb11 = K->ne[0]*sizeof(half);
        nb12 = nb11*K->ne[1];
        nb13 = nb12*K->ne[2];
        auto bs = ggml_blck_size(K->type);
        auto ts = ggml_type_size(K->type);

        // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
        // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
        //const size_t bs = ggml_blck_size(K->type);
        //const size_t ts = ggml_type_size(K->type);
        nb11 = nb11*bs*sizeof(half)/ts;
        nb12 = nb12*bs*sizeof(half)/ts;
        nb13 = nb13*bs*sizeof(half)/ts;

        //nb11 = nb11*bs*sizeof(half)/ts;
        //nb12 = nb12*bs*sizeof(half)/ts;
        //nb13 = nb13*bs*sizeof(half)/ts;
        //nb11 = K->ne[0]*sizeof(half);
        //nb12 = nb11*K->ne[1];
        //nb13 = nb12*K->ne[2];
    }

    if (need_f16_V && V->type != GGML_TYPE_F16) {
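The stride arithmetic in this hunk is the core of the quantized-cache handling: to_fp16 writes a contiguous fp16 copy, so the byte strides recorded for the quantized K tensor have to be rescaled by blck_size*sizeof(half)/type_size (the next hunk does the same for V). A self-contained numerical check of that conversion, using Q8_0's layout as an assumed example (32 values per block, 34 bytes per block):

// stride_check.cpp -- illustrative only; the block constants below are assumptions
// about the Q8_0 layout (2-byte scale + 32 int8 values per block).
#include <cassert>
#include <cstddef>
#include <cstdio>

static size_t stride_as_f16(size_t nb, size_t blck_size, size_t type_size) {
    return nb * blck_size * sizeof(unsigned short) / type_size; // sizeof(half) == 2 bytes
}

int main() {
    const size_t ne0  = 192;            // K head size, e.g. Mimo-2
    const size_t bs   = 32;             // values per Q8_0 block (assumed)
    const size_t ts   = 34;             // bytes per Q8_0 block (assumed)
    const size_t nb11 = ne0 / bs * ts;  // 204 bytes per quantized row
    assert(stride_as_f16(nb11, bs, ts) == ne0 * 2); // 384 bytes per fp16 row
    printf("quantized row: %zu bytes -> fp16 row: %zu bytes\n", nb11, stride_as_f16(nb11, bs, ts));
    return 0;
}

Presumably this is also where the "head sizes not being the same" bug from the commit message lived: in the next hunk the line nb21 = K->ne[0]*sizeof(half); derives a V-cache stride from the K head size, which only gives the right value when the two head sizes match (they do not for Mimo-2's 192/128 split).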
@@ -1831,17 +1830,16 @@ static void launch_fattn_new_mma(
        to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
        V_data = (char *) V_f16.ptr;

        nb21 = K->ne[0]*sizeof(half);
        nb22 = nb21*V->ne[1];
        nb23 = nb22*V->ne[2];
        auto bs = ggml_blck_size(V->type);
        auto ts = ggml_type_size(V->type);

        // Original PR in llama.cpp. Same comment as above for the K cache.
        //const size_t bs = ggml_blck_size(V->type);
        //const size_t ts = ggml_type_size(V->type);
        nb21 = nb21*bs*sizeof(half)/ts;
        nb22 = nb22*bs*sizeof(half)/ts;
        nb23 = nb23*bs*sizeof(half)/ts;

        //nb21 = nb21*bs*sizeof(half)/ts;
        //nb22 = nb22*bs*sizeof(half)/ts;
        //nb23 = nb23*bs*sizeof(half)/ts;
        //nb21 = V->ne[0]*sizeof(half);
        //nb22 = nb21*V->ne[1];
        //nb23 = nb22*V->ne[2];
    }
}
@@ -2145,10 +2143,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
    //}
    if (K->ne[0] == 192 && V->ne[0] == 128) {
        GGML_ASSERT(Q->ne[0] == 192);
        GGML_ASSERT(gqa_ratio == 1);
        //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
        //GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
        // Reduce compile time
        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
        //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
        return;
    }
    if (K->ne[0] == 192 && V->ne[0] == 192) {
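For context on the relaxed assert: gqa_ratio is the number of query heads that share one KV head; in llama.cpp-style dispatch code it is typically Q->ne[2] / K->ne[2], which is how it is assumed to be defined here. The DeepSeek MLA path apparently reaches this branch with a ratio of 1, while Mimo-2 uses genuine grouped-query attention, which is presumably why the dispatch now goes through switch_ncols2 instead of asserting. A tiny sketch of the quantity in question, with made-up head counts:

// gqa_ratio_sketch.cpp -- illustrative; shapes are hypothetical and the
// definition of gqa_ratio is assumed from how it is used in the dispatch code.
#include <cstdio>

int main() {
    const long n_q_heads  = 64;  // Q->ne[2] (hypothetical)
    const long n_kv_heads = 8;   // K->ne[2] (hypothetical)
    const long gqa_ratio  = n_q_heads / n_kv_heads;  // 8 query heads per KV head
    printf("gqa_ratio = %ld\n", gqa_ratio);          // > 1, so the old assert would fire
    return 0;
}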
@@ -93,7 +93,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    // On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
    //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]);
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
        ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
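The one-line change here (and its twin in ggml_cuda_fattn_is_supported further down) adds K->ne[0] == V->ne[0] to the batch-size-1 SWA exclusion, so a model whose K and V head sizes differ keeps going through the MMA path even for single-token decoding with a sliding window. A small evaluation of the old and new conditions under assumed Mimo-2-like shapes; every concrete number and flag below is an illustrative stand-in, not read from a real model.

// bs1_dispatch_sketch.cpp -- illustrative evaluation of the changed condition.
#include <cstdio>

int main() {
    const bool mma_available   = true;     // assumed: new_mma_available(cc) on this GPU
    const bool gqa_opt_applies = true;     // assumed: GQA optimization applicable
    const long q_ne1 = 1;                  // batch of one token (decode)
    const long n_swa = 4096;               // sliding-window attention in use (assumed)
    const long k_ne0 = 192, v_ne0 = 128;   // Mimo-2-like head sizes

    const bool old_rule = mma_available && gqa_opt_applies && !(q_ne1 == 1 && n_swa > 0);
    const bool new_rule = mma_available && gqa_opt_applies && !(q_ne1 == 1 && n_swa > 0 && k_ne0 == v_ne0);

    // old_rule = false (vector kernel would be preferred), new_rule = true (MMA path kept).
    printf("old: %d  new: %d\n", old_rule, new_rule);
    return 0;
}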
@@ -107,6 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    // so no other implementation works.
    //
    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
        //printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
        ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
        return;
    }
@@ -170,7 +171,7 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
    // On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
    //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0);
    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0 && K->ne[0] == V->ne[0]);
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
        return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);