Support Native Kimi K2 Thinking (#1663)

* [feat]: fix k2 prefill

* Update Kimi-K2-Thinking.md

* Create Kimi-K2-Thinking-Native.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking.md

* Update Kimi-K2-Thinking-Native.md

* [perf] optimize K2 MoE weight loading with per-expert pointers

- Avoid expensive torch.stack().contiguous() in Python (was ~6.6s)
- Use per-expert pointer arrays (gate_projs) instead of contiguous memory
- C++ worker pool performs parallel memcpy for TP (tensor-parallel) slicing
- Add LOAD_TIME_PROFILE for load_weights timing analysis

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
ErvinXie
2025-12-05 21:53:05 +08:00
committed by GitHub
parent 4850424345
commit 71f683acec
5 changed files with 419 additions and 70 deletions

View File

@@ -364,16 +364,34 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
raise NotImplementedError("RAWINT4 wrapper expects pre-quantized safetensor weights.")
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
import time
t0 = time.time()
base_key = f"model.layers.{self.layer_idx}"
weights = self.loader.load_experts(base_key)
t1 = time.time()
self.gate_weights = torch.stack(weights["gate"], dim=0).contiguous()
self.up_weights = torch.stack(weights["up"], dim=0).contiguous()
self.down_weights = torch.stack(weights["down"], dim=0).contiguous()
# Keep individual tensors instead of stacking - avoid expensive memory copy
# weights["gate"], weights["up"], weights["down"] are lists of tensors per expert
self.gate_weights = weights["gate"] # list of tensors
self.up_weights = weights["up"]
self.down_weights = weights["down"]
self.gate_scales = torch.stack(weights["gate_scale"], dim=0).to(torch.bfloat16).contiguous()
self.up_scales = torch.stack(weights["up_scale"], dim=0).to(torch.bfloat16).contiguous()
self.down_scales = torch.stack(weights["down_scale"], dim=0).to(torch.bfloat16).contiguous()
# Convert scales to bf16 individually
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
t2 = time.time()
# Build pointer lists: [numa_id][expert_id] -> pointer
# Since RAWINT4 has no numa sharding, numa dimension is 1
gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
up_ptrs = [[t.data_ptr() for t in self.up_weights]]
down_ptrs = [[t.data_ptr() for t in self.down_weights]]
gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
t3 = time.time()
moe_config = MOEConfig(
self.num_experts,
@@ -390,17 +408,20 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.group_size = 32
moe_config.quant_config.zero_point = False
moe_config.gate_proj = self.gate_weights.data_ptr()
moe_config.up_proj = self.up_weights.data_ptr()
moe_config.down_proj = self.down_weights.data_ptr()
moe_config.gate_scale = self.gate_scales.data_ptr()
moe_config.up_scale = self.up_scales.data_ptr()
moe_config.down_scale = self.down_scales.data_ptr()
# Use gate_projs instead of gate_proj for per-expert pointers
moe_config.gate_projs = gate_ptrs
moe_config.up_projs = up_ptrs
moe_config.down_projs = down_ptrs
moe_config.gate_scales = gate_scale_ptrs
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs
self.moe = AMXInt4_KGroup_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
self.cpu_infer.sync()
t5 = time.time()
del self.gate_weights
del self.up_weights
@@ -408,6 +429,18 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
del self.gate_scales
del self.up_scales
del self.down_scales
t6 = time.time()
print(
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
f"load_experts: {(t1-t0)*1000:.1f}ms, "
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
f"create_moe: {(t4-t3)*1000:.1f}ms, "
f"cpp_load_weights: {(t5-t4)*1000:.1f}ms, "
f"cleanup: {(t6-t5)*1000:.1f}ms, "
f"total: {(t6-t0)*1000:.1f}ms"
)
def submit_write_weight_scale_to_buffer(
self,