support Native BF16 format MoE. (#1788)

Author: Oql
Date: 2026-01-12 14:43:28 +08:00 (committed by GitHub)
Parent: ddb957596f
Commit: 5edc456749
11 changed files with 2149 additions and 501 deletions


@@ -4,16 +4,16 @@ import ctypes
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
-from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
+from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
 from kt_kernel_ext.moe import MOEConfig

 try:
-    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
+    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE
     _HAS_AMX_SUPPORT = True
 except (ImportError, AttributeError):
     _HAS_AMX_SUPPORT = False
-    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
+    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None

 from typing import Optional
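
The guarded import above keeps the package importable when the compiled kt_kernel_ext extension was built without AMX kernels: each class name falls back to None, and the constructor checks later in this diff turn a missing backend into a clear RuntimeError. A standalone sketch of the same pattern, reduced to the new BF16 entry (make_bf16_moe is a hypothetical helper, not part of this commit):

# Optional-backend import pattern, condensed to the BF16 kernel only.
try:
    from kt_kernel_ext.moe import AMXBF16_MOE  # compiled extension; may be absent
    _HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
    _HAS_AMX_SUPPORT = False
    AMXBF16_MOE = None  # sentinel checked before instantiation

def make_bf16_moe(moe_config):
    """Hypothetical helper: fail fast if the build lacks BF16 AMX support."""
    if AMXBF16_MOE is None:
        raise RuntimeError("AMX backend with BF16 support is not available.")
    return AMXBF16_MOE(moe_config)
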
@@ -304,7 +304,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
 class NativeMoEWrapper(BaseMoEWrapper):
-    """Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
+    """Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format."""

     _native_loader_instance = None
@@ -330,6 +330,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
             raise RuntimeError("AMX backend with RAWINT4 support is not available.")
         if method == "FP8" and AMXFP8_MOE is None:
             raise RuntimeError("AMX backend with FP8 support is not available.")
+        if method == "BF16" and AMXBF16_MOE is None:
+            raise RuntimeError("AMX backend with BF16 support is not available.")

         super().__init__(
             layer_idx=layer_idx,
@@ -352,6 +354,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
                 NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
             elif method == "FP8":
                 NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
+            elif method == "BF16":
+                NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
             else:
                 raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
         self.loader = NativeMoEWrapper._native_loader_instance
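
The loader is cached on the class (_native_loader_instance), so every layer's wrapper shares a single SafeTensor loader per process. A table-driven sketch of the same selection logic, assuming the first branch above handles RAWINT4 (get_native_loader is a hypothetical refactoring, not code from this commit):

# Hypothetical table-driven version of the method -> loader dispatch.
_LOADER_CLASSES = {
    "RAWINT4": CompressedSafeTensorLoader,
    "FP8": FP8SafeTensorLoader,
    "BF16": BF16SafeTensorLoader,
}

def get_native_loader(method, weight_path):
    if NativeMoEWrapper._native_loader_instance is None:
        try:
            loader_cls = _LOADER_CLASSES[method]
        except KeyError:
            raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
        NativeMoEWrapper._native_loader_instance = loader_cls(weight_path)
    return NativeMoEWrapper._native_loader_instance
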
@@ -386,28 +390,42 @@ class NativeMoEWrapper(BaseMoEWrapper):
         self.up_weights = weights["up"]
         self.down_weights = weights["down"]

-        # Convert scales to bf16 individually
-        # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
-        # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
-        # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
-        self.gate_scales = weights["gate_scale"]
-        self.up_scales = weights["up_scale"]
-        self.down_scales = weights["down_scale"]
-        if self.method == "RAWINT4":
-            assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
-        elif self.method == "FP8":
-            assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
+        # BF16 is stored unscaled; every other method ships per-expert scales
+        if self.method == "BF16":
+            self.gate_scales = None
+            self.up_scales = None
+            self.down_scales = None
+        else:
+            # Convert scales to bf16 individually
+            # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
+            # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
+            # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+            self.gate_scales = weights["gate_scale"]
+            self.up_scales = weights["up_scale"]
+            self.down_scales = weights["down_scale"]
+            if self.method == "RAWINT4":
+                assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
+            elif self.method == "FP8":
+                assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"

         t2 = time.time()

         # Build pointer lists: [numa_id][expert_id] -> pointer
-        # Since RAWINT4 has no numa sharding, numa dimension is 1
+        # Since RAWINT4/FP8/BF16 have no numa sharding, the numa dimension is 1
         gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
         up_ptrs = [[t.data_ptr() for t in self.up_weights]]
         down_ptrs = [[t.data_ptr() for t in self.down_weights]]
-        gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
-        up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
-        down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
+        # BF16 has no scales; pass 0 in each slot (nullptr on the native side)
+        if self.method == "BF16":
+            gate_scale_ptrs = [[0 for _ in self.gate_weights]]
+            up_scale_ptrs = [[0 for _ in self.up_weights]]
+            down_scale_ptrs = [[0 for _ in self.down_weights]]
+        else:
+            gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
+            up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
+            down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]

         t3 = time.time()
         moe_config = MOEConfig(
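
The native kernel receives raw memory addresses rather than tensors: a nested list indexed [numa_id][expert_id]. Because these native formats are not NUMA-sharded, the outer dimension is always 1, and in the BF16 path the scale slots carry 0 so the C++ side sees a null pointer. A toy illustration with made-up shapes:

import torch

# Two toy experts, no NUMA sharding -> outer list has length 1.
experts = [torch.randn(4, 8, dtype=torch.bfloat16).contiguous() for _ in range(2)]
gate_ptrs = [[t.data_ptr() for t in experts]]  # [numa_id][expert_id] -> address
assert len(gate_ptrs) == 1 and len(gate_ptrs[0]) == 2

# BF16 has no dequantization scales, so the scale slots are zero-filled;
# the native side can treat 0 as nullptr.
gate_scale_ptrs = [[0 for _ in experts]]
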
@@ -444,6 +462,9 @@ class NativeMoEWrapper(BaseMoEWrapper):
             moe_config.quant_config.group_size = 128
             moe_config.quant_config.zero_point = False
             self.moe = AMXFP8_MOE(moe_config)
+        elif self.method == "BF16":
+            # BF16 needs no quantization config
+            self.moe = AMXBF16_MOE(moe_config)

         t4 = time.time()
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
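
The new branch skips quant_config entirely because a bf16 weight is self-describing: it keeps fp32's 8-bit exponent, so no per-group scale is needed, whereas FP8 or INT4 codes must be multiplied back by a stored scale. A small illustration (toy values, not the kernel's actual dequantization):

import torch

# A bf16 weight decodes directly -- no scale involved.
w = torch.tensor([0.15625], dtype=torch.bfloat16)
print(w.float())  # decodes to 0.15625 with no extra metadata

# FP8/INT4-style storage reconstructs weights as code * per-group scale,
# which is why those branches populate quant_config and pass scale pointers.
code = torch.tensor([5], dtype=torch.int8)  # toy quantized code
scale = torch.tensor([0.03125])             # toy per-group scale (fp32 for FP8, per the assert above)
print(code.float() * scale)                 # reconstructs 0.15625
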
@@ -453,9 +474,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
         del self.gate_weights
         del self.up_weights
         del self.down_weights
-        del self.gate_scales
-        del self.up_scales
-        del self.down_scales
+        if self.gate_scales is not None:
+            del self.gate_scales
+            del self.up_scales
+            del self.down_scales

         t6 = time.time()
         print(
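
Since the kernel reads expert memory through data_ptr(), a wrong dtype or non-contiguous layout would only surface as garbage output or a crash at inference time. A hypothetical pre-flight check one could run on the BF16 path before submitting load_weights_task (not part of this commit):

import torch

def check_bf16_experts(tensors):
    """Hypothetical sanity check for weights handed to AMXBF16_MOE by address."""
    for i, t in enumerate(tensors):
        assert t.dtype == torch.bfloat16, f"expert {i}: expected bf16, got {t.dtype}"
        assert t.is_contiguous(), f"expert {i}: kernel expects contiguous memory"
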