feat: add deferred expert scheduling support

This commit is contained in:
chenht2022
2025-10-31 08:03:37 +00:00
parent 4a9b6cd99e
commit dd4377b60b

View File

@@ -12,7 +12,7 @@ implementations, handling weight loading, buffer management, and forward inferen
from __future__ import annotations
import torch
from typing import List, Dict
from typing import Dict, List, Optional, Tuple
from safetensors import safe_open
import os
import ctypes
@@ -146,30 +146,60 @@ class KExpertsCPUBuffer:
capture_buffers: Dict = dict()
temp_bs: int = 0
temp_buffer: tuple = tuple()
buffer_depth: int = 2
@classmethod
def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):
hidden_size = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_size)
batch_size, hidden_size = hidden_states.shape
batch_size = hidden_states.shape[0]
if batch_size in KExpertsCPUBuffer.capture_buffers:
return KExpertsCPUBuffer.capture_buffers[batch_size]
if batch_size == KExpertsCPUBuffer.temp_bs:
return KExpertsCPUBuffer.temp_buffer
if batch_size in cls.capture_buffers:
return cls.capture_buffers[batch_size]
if batch_size == cls.temp_bs:
return cls.temp_buffer
input_tensor_cpu = torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
expert_ids_cpu = torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
weights_cpu = torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
output_cpu = torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
bsz_tensor_cpu = torch.tensor((batch_size), device="cpu", dtype=torch.int32, pin_memory=True)
output_gpu = torch.zeros_like(hidden_states)
input_tensor_cpu = [
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
for _ in range(cls.buffer_depth)
]
immediate_experts_ids_cpu = [
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
for _ in range(cls.buffer_depth)
]
deferred_experts_ids_cpu = [
torch.full((batch_size, num_experts_per_tok), -1, device="cpu", dtype=torch.long, pin_memory=True)
for _ in range(cls.buffer_depth)
]
weights_cpu = [
torch.zeros((batch_size, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
for _ in range(cls.buffer_depth)
]
output_cpu = [
torch.zeros((batch_size, hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
for _ in range(cls.buffer_depth)
]
bsz_tensor_cpu = [
torch.zeros((1,), device="cpu", dtype=torch.int32, pin_memory=True)
for _ in range(cls.buffer_depth)
]
output_gpu = [
torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)
for _ in range(cls.buffer_depth)
]
cur_buffer = (input_tensor_cpu, expert_ids_cpu, weights_cpu, output_cpu, bsz_tensor_cpu, output_gpu)
if batch_size in KExpertsCPUBuffer.capture_bs:
KExpertsCPUBuffer.capture_buffers[batch_size] = cur_buffer
KExpertsCPUBuffer.temp_bs = batch_size
KExpertsCPUBuffer.temp_buffer = cur_buffer
cur_buffer = (
input_tensor_cpu,
immediate_experts_ids_cpu,
deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
bsz_tensor_cpu,
output_gpu,
)
if batch_size in cls.capture_bs:
cls.capture_buffers[batch_size] = cur_buffer
cls.temp_bs = batch_size
cls.temp_buffer = cur_buffer
return cur_buffer
@@ -181,6 +211,7 @@ class AMXMoEWrapper:
_cpu_infer_instance = None
_safetensor_loader_instance = None
_layer_has_pending_deferred: Dict[int, bool] = {}
def __init__(
self,
@@ -195,6 +226,7 @@ class AMXMoEWrapper:
amx_weight_path: str,
chunked_prefill_size: int,
cpu_save: bool = False,
max_deferred_experts_per_token: Optional[int] = None,
):
"""
Initialize AMX MoE Wrapper.
@@ -211,6 +243,7 @@ class AMXMoEWrapper:
amx_weight_path: Path to AMX weights
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
"""
self.layer_idx = layer_idx
@@ -222,6 +255,9 @@ class AMXMoEWrapper:
self.amx_weight_path = amx_weight_path
self.chunked_prefill_size = chunked_prefill_size
self.cpu_save = cpu_save
self.max_deferred_experts_per_token = int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
# Initialize CPU inference engine (singleton)
if AMXMoEWrapper._cpu_infer_instance is None:
@@ -461,6 +497,36 @@ class AMXMoEWrapper:
del self.up_scales
del self.down_scales
def select_deferred_experts(
self,
expert_ids: torch.Tensor,
expert_scores: torch.Tensor,
protected_k: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, topk = expert_ids.shape
device = expert_ids.device
protected_k = max(0, min(int(protected_k), topk))
if protected_k == 0:
deferred_ids = expert_ids.clone()
immediate_ids = torch.full_like(expert_ids, -1)
return immediate_ids, deferred_ids
topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)
protected_indices = topk_result.indices
protected_ids = torch.gather(expert_ids, -1, protected_indices)
protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)
protected_flag.scatter_(0, protected_ids.reshape(-1), 1)
protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)
protected_mask = protected_mask_flat.view(batch, topk)
immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)
deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)
return immediate_ids, deferred_ids
def submit_forward(
self,
hidden_states: torch.Tensor,
@@ -477,36 +543,72 @@ class AMXMoEWrapper:
topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]
cuda_stream: CUDA stream for synchronization
"""
# Get CPU buffers
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
batch_size = flat_hidden_states.shape[0]
(
input_tensor_cpu,
expert_ids_cpu,
immediate_experts_ids_cpu,
deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
bsz_tensor_cpu,
output_gpu,
) = KExpertsCPUBuffer.get_buffer(hidden_states, self.num_experts_per_tok)
_output_gpu,
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
# Copy data to CPU
topk_ids = topk_ids.to(torch.long)
input_tensor_cpu.copy_(hidden_states, non_blocking=True)
expert_ids_cpu.copy_(topk_ids, non_blocking=True)
weights_cpu.copy_(topk_weights, non_blocking=True)
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth
# Submit task
bsz_slot_tensor = bsz_tensor_cpu[current_slot]
bsz_slot_tensor.fill_(batch_size)
deferred_experts_ids_cpu[current_slot].fill_(-1)
topk_ids_long = topk_ids.to(torch.long)
immediate_ids: torch.Tensor
deferred_ids: Optional[torch.Tensor]
if self.max_deferred_experts_per_token > 0:
protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token
immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)
else:
immediate_ids = topk_ids_long
deferred_ids = None
input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)
weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)
immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)
incremental = AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)
self.cpu_infer.submit_with_cuda_stream(
cuda_stream,
self.moe.forward_task(
bsz_tensor_cpu.data_ptr(),
expert_ids_cpu.size(-1),
expert_ids_cpu.data_ptr(),
weights_cpu.data_ptr(),
input_tensor_cpu.data_ptr(),
output_cpu.data_ptr(),
False,
bsz_slot_tensor.data_ptr(),
immediate_experts_ids_cpu[current_slot].size(-1),
immediate_experts_ids_cpu[current_slot].data_ptr(),
weights_cpu[current_slot].data_ptr(),
input_tensor_cpu[current_slot].data_ptr(),
output_cpu[current_slot].data_ptr(),
incremental,
),
)
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
if deferred_ids is not None:
deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(
cuda_stream,
self.moe.forward_task(
bsz_slot_tensor.data_ptr(),
deferred_experts_ids_cpu[current_slot].size(-1),
deferred_experts_ids_cpu[current_slot].data_ptr(),
weights_cpu[current_slot].data_ptr(),
input_tensor_cpu[current_slot].data_ptr(),
output_cpu[next_slot].data_ptr(),
False,
),
)
AMXMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True
def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:
"""
Synchronize and retrieve forward inference results.
@@ -518,18 +620,22 @@ class AMXMoEWrapper:
Returns:
output_gpu: Output tensor on GPU
"""
flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
(
input_tensor_cpu,
expert_ids_cpu,
immediate_experts_ids_cpu,
_deferred_experts_ids_cpu,
weights_cpu,
output_cpu,
bsz_tensor_cpu,
_bsz_tensor_cpu,
output_gpu,
) = KExpertsCPUBuffer.get_buffer(hidden_states, self.num_experts_per_tok)
) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)
self.cpu_infer.sync_with_cuda_stream(cuda_stream)
output_gpu.copy_(output_cpu, non_blocking=True)
return output_gpu
current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth
allow_pending = 1 if AMXMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0
self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)
output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)
return output_gpu[current_slot]
def forward(
self,