Support Native Kimi K2 Thinking (#1663)
* [feat]: fix k2 prefill
* Update Kimi-K2-Thinking.md
* Create Kimi-K2-Thinking-Native.md
* Update Kimi-K2-Thinking.md
* Update Kimi-K2-Thinking.md
* Update Kimi-K2-Thinking-Native.md
* [perf] optimize K2 MoE weight loading with per-expert pointers
  - Avoid expensive torch.stack().contiguous() in Python (was ~6.6s)
  - Use per-expert pointer arrays (gate_projs) instead of contiguous memory
  - C++ worker pool performs parallel memcpy for TP slicing
  - Add LOAD_TIME_PROFILE for load_weights timing analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
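The core of the MoE loading change is to stop materializing one contiguous tensor per projection (which copies every expert's weight bytes inside Python) and instead hand the C++ side one raw pointer per expert. A minimal, self-contained sketch of that contrast; the shapes and expert count are illustrative, not the real Kimi K2 dimensions:

```python
# Illustrative sketch only: old stack-and-copy path vs. new per-expert pointer path.
# Shapes and expert count are made up for the demo.
import time
import torch

num_experts, rows, cols = 64, 1024, 512
experts = [torch.randint(0, 255, (rows, cols), dtype=torch.uint8) for _ in range(num_experts)]

# Old path: torch.stack(...).contiguous() allocates a fresh buffer and copies
# every expert's bytes - cost grows with total weight size.
t0 = time.time()
stacked = torch.stack(experts, dim=0).contiguous()
print(f"stack+copy: {(time.time() - t0) * 1000:.1f} ms, base ptr={stacked.data_ptr():#x}")

# New path: keep the per-expert tensors and only collect their raw addresses.
# No bytes move in Python; the consumer (the C++ worker pool, per the commit
# message) can memcpy/slice each expert in place and in parallel.
t0 = time.time()
gate_ptrs = [[t.data_ptr() for t in experts]]  # [numa_id][expert_id] -> pointer
print(f"pointer list: {(time.time() - t0) * 1000:.1f} ms, {len(gate_ptrs[0])} experts")
```

Building the pointer list is effectively free, while the stack-and-copy path scales with the total number of weight bytes, which is where the ~6.6 s reported in the commit message was going.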
```diff
@@ -364,16 +364,34 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
             raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
 
     def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
+        import time
+
+        t0 = time.time()
         base_key = f"model.layers.{self.layer_idx}"
         weights = self.loader.load_experts(base_key)
+        t1 = time.time()
 
-        self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
-        self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
-        self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
+        # Keep individual tensors instead of stacking - avoid expensive memory copy
+        # weights["gate"], weights["up"], weights["down"] are lists of tensors per expert
+        self.gate_weights = weights["gate"] # list of tensors
+        self.up_weights = weights["up"]
+        self.down_weights = weights["down"]
 
-        self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
-        self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
-        self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
+        # Convert scales to bf16 individually
+        self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
+        self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
+        self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+        t2 = time.time()
+
+        # Build pointer lists: [numa_id][expert_id] -> pointer
+        # Since RAWINT4 has no numa sharding, numa dimension is 1
+        gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
+        up_ptrs = [[t.data_ptr() for t in self.up_weights]]
+        down_ptrs = [[t.data_ptr() for t in self.down_weights]]
+        gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
+        up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
+        down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
+        t3 = time.time()
 
         moe_config = MOEConfig(
             self.num_experts,
```
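The pointer lists built above are nested as [numa_id][expert_id]; on this RAWINT4 path the outer dimension is always 1 because there is no NUMA sharding. A tiny sanity-check sketch of that layout (the expert count and tensor shapes are placeholders, not the real model's):

```python
# Placeholder check of the [numa_id][expert_id] pointer layout built in the hunk above.
import torch

num_experts = 4  # illustrative; the real count comes from the model config
gate_weights = [torch.empty(16, 16, dtype=torch.uint8) for _ in range(num_experts)]

gate_ptrs = [[t.data_ptr() for t in gate_weights]]  # single NUMA entry on this path

assert len(gate_ptrs) == 1                  # no NUMA sharding for RAWINT4
assert len(gate_ptrs[0]) == num_experts     # one raw pointer per expert
assert all(p != 0 for p in gate_ptrs[0])    # every tensor owns real storage
```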
```diff
@@ -390,17 +408,20 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         moe_config.quant_config.group_size = 32
         moe_config.quant_config.zero_point = False
 
-        moe_config.gate_proj = self.gate_weights.data_ptr()
-        moe_config.up_proj = self.up_weights.data_ptr()
-        moe_config.down_proj = self.down_weights.data_ptr()
-        moe_config.gate_scale = self.gate_scales.data_ptr()
-        moe_config.up_scale = self.up_scales.data_ptr()
-        moe_config.down_scale = self.down_scales.data_ptr()
+        # Use gate_projs instead of gate_proj for per-expert pointers
+        moe_config.gate_projs = gate_ptrs
+        moe_config.up_projs = up_ptrs
+        moe_config.down_projs = down_ptrs
+        moe_config.gate_scales = gate_scale_ptrs
+        moe_config.up_scales = up_scale_ptrs
+        moe_config.down_scales = down_scale_ptrs
 
         self.moe = AMXInt4_KGroup_MOE(moe_config)
+        t4 = time.time()
 
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
         self.cpu_infer.sync()
+        t5 = time.time()
 
         del self.gate_weights
         del self.up_weights
@@ -408,6 +429,18 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         del self.gate_scales
         del self.up_scales
         del self.down_scales
+        t6 = time.time()
+
+        print(
+            f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
+            f"load_experts: {(t1-t0)*1000:.1f}ms, "
+            f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
+            f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
+            f"create_moe: {(t4-t3)*1000:.1f}ms, "
+            f"cpp_load_weights: {(t5-t4)*1000:.1f}ms, "
+            f"cleanup: {(t6-t5)*1000:.1f}ms, "
+            f"total: {(t6-t0)*1000:.1f}ms"
+        )
 
     def submit_write_weight_scale_to_buffer(
         self,
```
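The commit message mentions a LOAD_TIME_PROFILE switch for this timing output, but the wiring is not part of the hunks shown above, where the print is unconditional. A hedged sketch of how such an environment-variable gate could look; the variable name follows the commit message, and everything else is an assumption rather than the ktransformers implementation:

```python
# Sketch: gate the per-layer timing print behind LOAD_TIME_PROFILE.
# The env var name comes from the commit message; the helper itself is illustrative.
import os

LOAD_TIME_PROFILE = os.environ.get("LOAD_TIME_PROFILE", "0") == "1"

def profile_print(msg: str) -> None:
    """Print load-time profiling info only when LOAD_TIME_PROFILE=1."""
    if LOAD_TIME_PROFILE:
        print(msg)

# Usage inside load_weights would then be, for example:
# profile_print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] total: {(t6 - t0) * 1000:.1f}ms")
```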