# Mirror of https://github.com/kvcache-ai/ktransformers.git
# (synced 2026-03-14 18:37:23 +00:00)
"""Smoke test: construct an AMX Int4 MOE from raw bf16 weight buffers.

Allocates uninitialized gate/up/down expert projection tensors, passes their
raw data pointers into a ``MOEConfig``, and instantiates an ``AMXInt4_MOE``
backed by a ``CPUInfer`` thread pool.  Success criterion is simply that
construction completes without error.
"""

import os
import sys

# Make the freshly built extension importable without installing it.
sys.path.insert(0, os.path.dirname(__file__) + "/../build")

import ctypes  # no longer needed for pointer extraction; kept since other tests in this dir may rely on this file's import side effects

import torch

from kt_kernel import kt_kernel_ext
from kt_kernel_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE

# Model geometry.
intermediate_size_full = 2048
moe_intermediate_size = 3072
hidden_size = 7168
experts_num = 256
num_experts_per_tok = 8

# CPU inference backend with 97 worker threads.
cpu_infer = kt_kernel_ext.CPUInfer(97)

# Uninitialized bf16 expert weights, laid out [experts, out_features, in_features].
# NOTE(review): tensors are sized with intermediate_size_full (2048) while the
# config below receives moe_intermediate_size (3072) — confirm this mismatch is
# intentional (e.g. the kernel only reads the configured slice).
up = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")
gate = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")
down = torch.empty(experts_num, hidden_size, intermediate_size_full, dtype=torch.bfloat16, device="cpu")

# Tensor.data_ptr() already returns the buffer address as a plain int; the
# previous ctypes cast/addressof round-trip was an identity transformation.
gate_ptr = gate.data_ptr()
up_ptr = up.data_ptr()
down_ptr = down.data_ptr()

moe_config = MOEConfig(
    experts_num,
    num_experts_per_tok,
    hidden_size,
    moe_intermediate_size,
)
moe_config.layer_idx = 45
moe_config.pool = cpu_infer.backend_
moe_config.max_len = 1024  # TODO(zbx): multi cuda graph
moe_config.gate_proj = gate_ptr
moe_config.up_proj = up_ptr
moe_config.down_proj = down_ptr
moe_config.path = ""  # empty path: weights come from the raw pointers above, not disk

# Construction is the test: quantizes/packs the bf16 buffers into AMX Int4 form.
moe = AMXInt4_MOE(moe_config)