Files
ktransformers/kt-kernel/examples/test-debug.py
2025-12-17 19:46:32 +08:00

40 lines
1.4 KiB
Python

# Debug/example script: builds an AMX Int4 MoE layer through the kt_kernel
# C++ extension, handing it raw pointers into CPU bfloat16 weight tensors.
import os
import sys

# Make the locally built extension importable when run from examples/.
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
import ctypes
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE

# Model geometry.
# NOTE(review): the weight tensors below are allocated with
# intermediate_size_full (2048), but MOEConfig receives
# moe_intermediate_size (3072) — confirm this mismatch is intentional
# for this debug script.
intermediate_size_full = 2048
moe_intermediate_size = 3072
hidden_size = 7168
experts_num = 256
num_experts_per_tok = 8

# CPU inference backend with 97 worker threads.
cpu_infer = kt_kernel_ext.CPUInfer(97)

# Uninitialized expert weights, laid out [experts, out_features, in_features],
# bf16 on CPU (contents are garbage — this script only exercises construction).
up = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")
gate = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device="cpu")
down = torch.empty(experts_num, hidden_size, intermediate_size_full, dtype=torch.bfloat16, device="cpu")

# Tensor.data_ptr() already returns the raw address as a Python int.
# The original ctypes round-trip
#   addressof(cast(ptr, POINTER(c_uint64)).contents)
# re-derives the exact same integer (cast -> pointer at addr, .contents ->
# object at addr, addressof -> addr), so it is dropped as a no-op.
gate_ptr = gate.data_ptr()
up_ptr = up.data_ptr()
down_ptr = down.data_ptr()

moe_config = MOEConfig(
    experts_num,
    num_experts_per_tok,
    hidden_size,
    moe_intermediate_size,
)
moe_config.layer_idx = 45
moe_config.pool = cpu_infer.backend_
moe_config.max_len = 1024  # TODO(zbx): multi cuda graph
moe_config.gate_proj = gate_ptr
moe_config.up_proj = up_ptr
moe_config.down_proj = down_ptr
moe_config.path = ""  # empty path: weights come from the raw pointers above

# Construct the AMX int4 MoE kernel from the config.
moe = AMXInt4_MOE(moe_config)