mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
[feat](kt-kernel): CPU-GPU experts sched (#1796)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
@@ -41,7 +42,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
gpu_experts_mask: Optional[torch.Tensor],
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
@@ -59,7 +60,10 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
num_experts_per_tok: Number of experts per token (top-k)
|
||||
hidden_size: Hidden dimension size
|
||||
moe_intermediate_size: MoE intermediate size
|
||||
num_gpu_experts: Number of experts to run on GPU
|
||||
gpu_experts_mask: Boolean mask indicating which experts are on GPU.
|
||||
Shape: [num_experts], dtype: torch.bool.
|
||||
mask[i] = True means expert i is on GPU.
|
||||
If None, all experts are on CPU.
|
||||
cpuinfer_threads: Number of CPU inference threads
|
||||
threadpool_count: Number of NUMA subpools
|
||||
weight_path: Path to AMX weights (SafeTensor format)
|
||||
@@ -81,7 +85,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_gpu_experts=num_gpu_experts,
|
||||
gpu_experts_mask=gpu_experts_mask,
|
||||
cpuinfer_threads=cpuinfer_threads,
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=weight_path,
|
||||
@@ -139,7 +143,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
self.gpu_experts_mask.data_ptr(),
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
@@ -254,7 +258,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
self.gpu_experts_mask.data_ptr(),
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
@@ -323,7 +327,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
num_experts_per_tok: int,
|
||||
hidden_size: int,
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
gpu_experts_mask: Optional[torch.Tensor],
|
||||
cpuinfer_threads: int,
|
||||
threadpool_count: int,
|
||||
weight_path: str,
|
||||
@@ -349,7 +353,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
num_experts_per_tok=num_experts_per_tok,
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_size=moe_intermediate_size,
|
||||
num_gpu_experts=num_gpu_experts,
|
||||
gpu_experts_mask=gpu_experts_mask,
|
||||
cpuinfer_threads=cpuinfer_threads,
|
||||
threadpool_count=threadpool_count,
|
||||
weight_path=weight_path,
|
||||
@@ -448,7 +452,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
self.gpu_experts_mask.data_ptr(),
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
|
||||
Reference in New Issue
Block a user