mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
* fix(amx): add BufferASmallKGroupImpl to fix buffer overflow in from_mat The original BufferAKGroupImpl::from_mat writes 64 bytes per K_STEP iteration but when K_STEP=32 (for GemmKernel224Int4SmallKGroup), this causes buffer overflow. BufferASmallKGroupImpl overrides from_mat to write only 32 bytes per iteration. * perf(k2-moe): optimize memory allocation with pooled buffers - Replace per-expert buffer allocation with shared memory pools - Dynamically assign buffer slices based on activated experts - Add group_size inference from scale tensor shape in amx.py * delete kimi k2 forward test * add TODO comment for pool_count_ calculation
This commit is contained in:
@@ -66,6 +66,18 @@ class AMX_K2_MOE_TP {
|
||||
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
|
||||
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
|
||||
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
|
||||
|
||||
size_t pool_count_ = 0; // rows reserved in each scratch pool
|
||||
size_t gate_up_ba_pool_bytes_ = 0;
|
||||
size_t gate_bc_pool_bytes_ = 0;
|
||||
size_t up_bc_pool_bytes_ = 0;
|
||||
size_t down_ba_pool_bytes_ = 0;
|
||||
size_t down_bc_pool_bytes_ = 0;
|
||||
void* gate_up_ba_pool_ = nullptr;
|
||||
void* gate_bc_pool_ = nullptr;
|
||||
void* up_bc_pool_ = nullptr;
|
||||
void* down_ba_pool_ = nullptr;
|
||||
void* down_bc_pool_ = nullptr;
|
||||
#ifdef CHECK
|
||||
char verify_bb[100000000];
|
||||
char check_bb[100000000];
|
||||
@@ -215,18 +227,23 @@ class AMX_K2_MOE_TP {
|
||||
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
|
||||
group_size, down_bb_ptr));
|
||||
}
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
|
||||
T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
|
||||
T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.hidden_size));
|
||||
}
|
||||
assert(T::BufferA::M_STEP == T::BufferC::M_STEP);
|
||||
// TODO: this pool_count_ sizing change still needs to be applied to the other *.hpp kernel variants
|
||||
// The (config_.expert_num * T::BufferA::M_STEP) term in pool_count_ reserves M_STEP-row padding for each expert.
|
||||
pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::BufferA::M_STEP;
|
||||
|
||||
gate_up_ba_pool_bytes_ = (T::BufferA::required_size(pool_count_, config_.hidden_size, group_size)) + pool_count_ * 64;
|
||||
gate_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.intermediate_size)) + pool_count_ * 64;
|
||||
up_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.intermediate_size)) + pool_count_ * 64;
|
||||
down_ba_pool_bytes_ = (T::BufferA::required_size(pool_count_, config_.intermediate_size, group_size)) + pool_count_ * 64;
|
||||
down_bc_pool_bytes_ = (T::BufferC::required_size(pool_count_, config_.hidden_size)) + pool_count_ * 64;
|
||||
|
||||
mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
|
||||
mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
|
||||
mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
|
||||
mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
|
||||
mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);
|
||||
|
||||
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
|
||||
}
|
||||
|
||||
@@ -552,13 +569,61 @@ class AMX_K2_MOE_TP {
|
||||
// activated_expert counting has completed at this point
|
||||
|
||||
size_t offset = 0;
|
||||
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
|
||||
void* gate_bc_pool_ptr = gate_bc_pool_;
|
||||
void* up_bc_pool_ptr = up_bc_pool_;
|
||||
void* down_ba_pool_ptr = down_ba_pool_;
|
||||
void* down_bc_pool_ptr = down_bc_pool_;
|
||||
constexpr size_t M_STEP = T::BufferA::M_STEP;
|
||||
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
|
||||
size_t used_pool_m = 0;
|
||||
size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
|
||||
used_pool_bytes_bc_down = 0;
|
||||
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
|
||||
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
|
||||
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
|
||||
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
|
||||
offset += m_local_num_[i];
|
||||
|
||||
if (m_local_num_[i] == 0)
|
||||
continue;
|
||||
size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
|
||||
gate_up_ba_[i]->max_m = max_m;
|
||||
gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
|
||||
size_t ba_size = align64(T::BufferA::required_size(max_m, config_.hidden_size, group_size));
|
||||
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
|
||||
gate_bc_[i]->max_m = max_m;
|
||||
gate_bc_[i]->set_data(gate_bc_pool_ptr);
|
||||
size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size));
|
||||
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
|
||||
up_bc_[i]->max_m = max_m;
|
||||
up_bc_[i]->set_data(up_bc_pool_ptr);
|
||||
size_t bc_up_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size));
|
||||
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
|
||||
down_ba_[i]->max_m = max_m;
|
||||
down_ba_[i]->set_data(down_ba_pool_ptr);
|
||||
size_t ba_down_size = align64(T::BufferA::required_size(max_m, config_.intermediate_size, group_size));
|
||||
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
|
||||
down_bc_[i]->max_m = max_m;
|
||||
down_bc_[i]->set_data(down_bc_pool_ptr);
|
||||
size_t bc_down_size = align64(T::BufferC::required_size(max_m, config_.hidden_size));
|
||||
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
|
||||
used_pool_m += max_m;
|
||||
used_pool_bytes_a += ba_size;
|
||||
used_pool_bytes_bc_gate += bc_gate_size;
|
||||
used_pool_bytes_bc_up += bc_up_size;
|
||||
used_pool_bytes_ba_down += ba_down_size;
|
||||
used_pool_bytes_bc_down += bc_down_size;
|
||||
}
|
||||
assert(used_pool_m <= pool_count_);
|
||||
assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
|
||||
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
@@ -771,6 +836,59 @@ class AMX_K2_MOE_TP {
|
||||
offset += qlen;
|
||||
}
|
||||
|
||||
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
|
||||
void* gate_bc_pool_ptr = gate_bc_pool_;
|
||||
void* up_bc_pool_ptr = up_bc_pool_;
|
||||
void* down_ba_pool_ptr = down_ba_pool_;
|
||||
void* down_bc_pool_ptr = down_bc_pool_;
|
||||
constexpr size_t M_STEP = T::BufferA::M_STEP;
|
||||
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
|
||||
size_t used_pool_m = 0;
|
||||
size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
|
||||
used_pool_bytes_bc_down = 0;
|
||||
for (int i = 0; i < activated_expert; i++) {
|
||||
auto expert_idx = m_expert_id_map_[i];
|
||||
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
|
||||
|
||||
gate_up_ba_[expert_idx]->max_m = max_m;
|
||||
gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
|
||||
size_t ba_size = align64(T::BufferA::required_size(max_m, config_.hidden_size, group_size));
|
||||
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
|
||||
|
||||
gate_bc_[expert_idx]->max_m = max_m;
|
||||
gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
|
||||
size_t bc_gate_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size));
|
||||
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
|
||||
|
||||
up_bc_[expert_idx]->max_m = max_m;
|
||||
up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
|
||||
size_t bc_up_size = align64(T::BufferC::required_size(max_m, config_.intermediate_size));
|
||||
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
|
||||
|
||||
down_ba_[expert_idx]->max_m = max_m;
|
||||
down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
|
||||
size_t ba_down_size = align64(T::BufferA::required_size(max_m, config_.intermediate_size, group_size));
|
||||
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
|
||||
|
||||
down_bc_[expert_idx]->max_m = max_m;
|
||||
down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
|
||||
size_t bc_down_size = align64(T::BufferC::required_size(max_m, config_.hidden_size));
|
||||
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
|
||||
|
||||
used_pool_m += max_m;
|
||||
used_pool_bytes_a += ba_size;
|
||||
used_pool_bytes_bc_gate += bc_gate_size;
|
||||
used_pool_bytes_bc_up += bc_up_size;
|
||||
used_pool_bytes_ba_down += ba_down_size;
|
||||
used_pool_bytes_bc_down += bc_down_size;
|
||||
}
|
||||
assert(used_pool_m <= pool_count_);
|
||||
assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
|
||||
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
|
||||
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
|
||||
|
||||
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
|
||||
@@ -442,6 +442,78 @@ struct BufferAKGroupImpl {
|
||||
}
|
||||
};
|
||||
|
||||
// BufferASmallKGroupImpl: For kernels with K_STEP=32 (e.g., GemmKernel224Int4SmallKGroup).
// This fixes the buffer overflow issue where the base class writes 64 bytes per K_STEP iteration
// but the buffer is only sized for 32-byte steps.
//
// Inherits storage and layout from BufferAKGroupImpl<K>:
//   a              — packed int8 quantized activation buffer (destination of from_mat)
//   d              — per-(row, k-group) dequantization scales
//   k              — row length (number of input features)
//   k_group_count  — number of quantization groups per row (k / k_group_size)
//   k_group_size   — columns sharing one scale
//   max_m          — capacity in rows of the backing buffer
template <typename K>
struct BufferASmallKGroupImpl : public BufferAKGroupImpl<K> {
  using Base = BufferAKGroupImpl<K>;
  using Base::a;
  using Base::d;
  using Base::k;
  using Base::k_group_count;
  using Base::k_group_size;
  using Base::max_m;

  // Tiling parameters taken from the GEMM kernel this buffer feeds.
  static constexpr int M_STEP = K::M_STEP;
  static constexpr int K_STEP = K::K_STEP;
  static constexpr int K_BLOCK = K::K_BLOCK;

  // Forwards straight to the base; ptr is externally owned backing memory (pool slice).
  BufferASmallKGroupImpl(int max_m, int k, int k_group_size, void* ptr)
      : Base(max_m, k, k_group_size, ptr) {}

  // Override from_mat to write only 32 bytes per K_STEP iteration.
  //
  // Quantizes the m x k bf16 matrix `src` into the packed int8 layout expected by the
  // kernel, computing one symmetric int8 scale per (row, k-group).
  //   m        — number of valid rows in src (must be <= max_m)
  //   src      — row-major bf16 input, row stride k
  //   ith/nth  — worker index / worker count; only the single-threaded case is supported here
  // NOTE(review): assumes k_group_size is a multiple of 32 and k a multiple of K_STEP —
  // the tail-free inner loops rely on it; confirm against callers.
  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {
    assert(m <= max_m);
    assert(ith == 0 && nth == 1);

    // Calculate scale for each k_group (same as base class):
    // amax over the group's |values|, mapped onto the int8 range [-127, 127].
    for (int m_idx = 0; m_idx < m; m_idx++) {
      for (int kg = 0; kg < k_group_count; kg++) {
        float amax = 0.0f;
        int k_start = kg * k_group_size;
        int k_end = k_start + k_group_size;
        // 32 bf16 values per step, widened to two 16-lane fp32 vectors.
        for (int j = k_start; j < k_end; j += 32) {
          __m512 f0, f1;
          avx512_32xbf16_to_32xfp32((__m512i*)(src + m_idx * k + j), &f0, &f1);
          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));
          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));
        }
        // (1 << 7) - 1 == 127: symmetric int8 quantization, zero-point-free.
        d[m_idx * k_group_count + kg] = amax / ((1 << 7) - 1);
      }
    }

    // Quantization with 32-byte writes per K_STEP iteration.
    // Destination layout (outer to inner): K_BLOCK column panels, then M_STEP row tiles,
    // then K_STEP chunks, then rows within the tile — matching the kernel's load order.
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;  // row count padded to tile size
    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
        int k_block_size = std::min(K_BLOCK, k - k_block_begin);
        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
            // Get the scale for this k_group.
            int k_group_idx = (k_block_begin + k_begin) / k_group_size;
            float scale = d[(m_begin + i) * k_group_count + k_group_idx];
            // Guard against an all-zero group: use 0 as the inverse scale instead of 1/0.
            __m512 id = _mm512_set1_ps(scale ? 1.0f / scale : 0.0f);

            // Calculate destination - writes K_STEP (32) bytes.
            // This i * K_STEP (= i * 32) row stride is the fix: the base class advances
            // 64 bytes per row here, overrunning a buffer sized for 32-byte steps.
            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;

            // Only process 32 bytes (2 x __m512 -> 2 x __m128i) per iteration:
            // bf16 -> fp32, scale, round-convert to int32, saturate-narrow to int8.
            __m512 f0, f1;
            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);
            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));
            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));
            __m128i s0 = _mm512_cvtsepi32_epi8(i0);
            __m128i s1 = _mm512_cvtsepi32_epi8(i1);
            // Aligned 16-byte stores; dst must be 16-byte aligned (pool slices are 64-byte aligned).
            _mm_store_si128((__m128i*)dst, s0);
            _mm_store_si128((__m128i*)(dst + 16), s1);
          }
        }
      }
    }
  }
};
|
||||
|
||||
template <typename K>
|
||||
struct BufferBInt4Impl {
|
||||
using dt = typename K::dt;
|
||||
|
||||
@@ -2886,7 +2886,7 @@ struct GemmKernel224Int4SmallKGroup {
|
||||
static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }
|
||||
static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }
|
||||
|
||||
using BufferA = BufferAKGroupImpl<GemmKernel224Int4SmallKGroup>;
|
||||
using BufferA = BufferASmallKGroupImpl<GemmKernel224Int4SmallKGroup>;
|
||||
using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>; // Use new signed int4 buffer
|
||||
using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;
|
||||
|
||||
|
||||
@@ -404,8 +404,16 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
|
||||
# Infer group_size from scale shape (column-major layout)
|
||||
# For gate/up projection: in_features = hidden_size
|
||||
# So: group_size = hidden_size / scale.shape[1]
|
||||
scale_shape = self.gate_scales[0].shape
|
||||
group_size = self.hidden_size // scale_shape[1]
|
||||
print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}")
|
||||
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = 32
|
||||
moe_config.quant_config.group_size = group_size
|
||||
|
||||
moe_config.quant_config.zero_point = False
|
||||
|
||||
# Use gate_projs instead of gate_proj for per-expert pointers
|
||||
|
||||
Reference in New Issue
Block a user