Files
ktransformers/kt-kernel/examples/test_gate.py
2025-12-17 19:46:32 +08:00

217 lines
8.2 KiB
Python

import math
import os, sys
import time
from typing import Optional
os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import nn
import torch.nn.functional as F
# from modeling_deepseek_v3 import MoEGate
from configuration_deepseek_v3 import DeepseekV3Config
# Seed every RNG so the random gate parameters and inputs are reproducible
# (any integer works as the seed).
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Number of tokens routed in one gate forward call.
seqlen = 64

# Gate hyper-parameters, read from a default-constructed DeepSeek-V3 config.
config = DeepseekV3Config()
(
    hidden_size,
    num_experts_per_token,
    n_routed_experts,
    n_group,
    topk_group,
    routed_scaling_factor,
) = (
    getattr(config, field)
    for field in (
        "hidden_size",
        "num_experts_per_tok",
        "n_routed_experts",
        "n_group",
        "topk_group",
        "routed_scaling_factor",
    )
)

# Random FP32 gate projection (one row per routed expert) and per-expert
# score-correction bias; both the PyTorch reference and the native gate use
# these exact tensors. Draw order matters for reproducibility: weights first.
weights = torch.randn(n_routed_experts, hidden_size, dtype=torch.float32, device="cpu").contiguous()
bias = torch.randn(n_routed_experts, dtype=torch.float32, device="cpu").contiguous()
# An FP16 variant of the weights can be drawn the same way with dtype=torch.float16.
def load_fp32_tensor(file_path, shape, *, load=False):
    """Return a float32 CPU reference tensor of ``shape`` for debug comparisons.

    By default this is a stub: it returns zeros and never touches ``file_path``.
    The reference dumps live at developer-local absolute paths, so reading them
    unconditionally would break the script on other machines. Pass
    ``load=True`` to actually read the dump.

    Args:
        file_path: path of a raw float32 dump (native byte order). It is only
            opened when ``load`` is True.
        shape: target shape, either a tuple or a single int.
        load: when True, read ``file_path`` instead of returning zeros.

    Returns:
        A float32 tensor with the requested shape.

    Raises:
        OSError: if ``load`` is True and the file cannot be opened.
        An error is also raised by torch if the file size does not match
        ``shape``.
    """
    if not load:
        return torch.zeros(shape, dtype=torch.float32).to("cpu").contiguous()
    with open(file_path, "rb") as f:
        raw_data = f.read()
    # Wrap the bytes in a bytearray so frombuffer gets a writable buffer.
    # Immutable bytes produce a "non-writable buffer" warning, and the
    # resulting tensor must never be written to.
    tensor = torch.frombuffer(bytearray(raw_data), dtype=torch.float32)
    return tensor.view(shape)
class MoEGate(nn.Module):
    """Pure-PyTorch DeepSeek-V3 MoE gate used as the reference for the native gate.

    How experts are chosen:

    1. Every routed expert is scored with ``sigmoid(x @ W^T)``.
    2. For ``noaux_tc``, the choice is restricted to the best ``topk_group``
       expert groups.
    3. ``top_k`` experts are picked per token.

    The selected weights are renormalized when ``norm_topk_prob`` is set and
    ``top_k > 1``. They are then multiplied by ``routed_scaling_factor``.

    The forward pass also compares several intermediates against reference
    dumps fetched through ``load_fp32_tensor`` and prints the differences.
    """

    # Directory holding the intermediate tensors dumped for debugging.
    _DEBUG_DIR = "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug"
    # Set to True to also print the hidden-state / logits / bias comparisons.
    debug_verbose = False

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        # Whether the selected top-k weights are renormalized to sum to 1.
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        # Gate projection, shape (n_routed_experts, hidden_size).
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            # Per-expert bias, added to the scores only when choosing experts.
            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the gate projection with Kaiming-uniform (a=sqrt(5))."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def _debug_ref(self, name, shape):
        """Return the reference dump ``name`` from ``_DEBUG_DIR`` with ``shape``."""
        return load_fp32_tensor(f"{self._DEBUG_DIR}/{name}", shape)

    def forward(self, hidden_states):
        """Route every token to ``top_k`` experts.

        Args:
            hidden_states: tensor of shape ``(bsz, seq_len, hidden_size)``.

        Returns:
            ``(topk_idx, topk_weight)``, each of shape ``(bsz * seq_len, top_k)``.
            Experts are not returned in any particular order
            (``topk(sorted=False)``).

        Raises:
            NotImplementedError: if the scoring function is not ``"sigmoid"`` or
                the top-k method is not ``"noaux_tc"``.
        """
        bsz, seq_len, h = hidden_states.shape
        # All 2-D intermediates (and the debug references) are laid out per token.
        n_tokens = bsz * seq_len
        n_experts = self.n_routed_experts

        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        if self.debug_verbose:
            diff = (self._debug_ref("gate_input", (n_tokens, h)) - hidden_states).abs().max()
            print("hidden_states diff:", diff)
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
        if self.debug_verbose:
            diff = (self._debug_ref("gate_logits", (n_tokens, n_experts)) - logits).abs().max()
            print("logits diff:", diff)
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            correction_bias = self.e_score_correction_bias
            if self.debug_verbose:
                diff = (correction_bias - self._debug_ref("bias", (n_experts,))).abs().max()
                print("bias diff:", diff)
            scores_for_choice = scores.view(n_tokens, -1) + correction_bias.unsqueeze(0)
            diff = (scores_for_choice - self._debug_ref("scores_to_choice", (n_tokens, n_experts))).abs().max()
            print(f"score for choice diff = {diff}")
            # Rank each group by the sum of its two best (biased) expert scores.
            group_scores = (
                scores_for_choice.view(n_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
            )  # [n, n_group]
            diff = (group_scores - self._debug_ref("group_scores", (n_tokens, self.n_group))).abs().max()
            print(f"group scores diff = {diff}")
            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(n_tokens, self.n_group, n_experts // self.n_group)
                .reshape(n_tokens, -1)
            )  # [n, e]
            # Experts outside the selected groups can never be picked.
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
            tmp_scores_to_check = self._debug_ref("gate_logits_toped", (n_tokens, n_experts))
            is_close = torch.isclose(tmp_scores, tmp_scores_to_check, rtol=1e-2, atol=1e-2, equal_nan=True)
            print(f"tmp_score ok {is_close.all()}")
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            # Weights come from the un-biased scores; the bias only steers selection.
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
        return topk_idx, topk_weight
def torch_gate(hidden_states):
    """Run the PyTorch reference ``MoEGate`` on ``hidden_states``.

    The gate is built from the module-level ``config``. It is given the
    module-level ``weights`` and ``bias`` tensors as its projection and
    score-correction bias.

    Args:
        hidden_states: ``(seq_len, hidden_size)`` tokens; a batch dimension of
            1 is added. Input that is already ``(bsz, seq_len, hidden_size)``
            is used as-is.

    Returns:
        ``(topk_idx, topk_weight)`` as produced by ``MoEGate.forward``.
    """
    # Add the batch dimension out of place so the caller's tensor keeps its
    # shape; the previous in-place ``unsqueeze_`` silently reshaped it.
    if hidden_states.dim() == 2:
        hidden_states = hidden_states.unsqueeze(0)
    gate = MoEGate(config)
    gate.weight.data = weights
    gate.e_score_correction_bias.data = bias
    return gate(hidden_states)
def cpuinfer_gate(hidden_states, num_threads=64):
    """Route ``hidden_states`` through the native kt_kernel MoE gate.

    Args:
        hidden_states: token activations. They are flattened to
            ``(n_tokens, hidden_size)`` float32 (assumed layout expected by
            the kernel — TODO confirm).
        num_threads: value passed to ``kt_kernel_ext.CPUInfer`` (default 64,
            as before).

    Returns:
        ``(expert_ids, expert_weights)``, both of shape
        ``(n_tokens, num_experts_per_token)``, as int64 and float32
        respectively.
    """
    # The kernel reads raw memory through data_ptr(), so the input must be a
    # dense float32 buffer. The token count is taken from the tensor itself
    # rather than the global ``seqlen``, so the kernel never reads or writes
    # past the end of the buffers.
    hidden_states = hidden_states.reshape(-1, hidden_size).to(torch.float32).contiguous()
    n_tokens = hidden_states.shape[0]

    gate_config = kt_kernel_ext.gate.GateConfig(
        hidden_size,
        num_experts_per_token,
        n_routed_experts,
        n_group,
        topk_group,
    )
    cpu_infer = kt_kernel_ext.CPUInfer(num_threads)
    gate_config.routed_scaling_factor = routed_scaling_factor
    gate_config.pool = cpu_infer.backend_
    # Only raw pointers are handed over. The module-level ``weights``/``bias``
    # tensors therefore must stay alive while the gate is used.
    gate_config.weight = weights.data_ptr()
    gate_config.weight_type = ggml_type.FP32
    gate_config.e_score_correction_bias = bias.data_ptr()
    gate_config.e_score_correction_bias_type = ggml_type.FP32
    gate = kt_kernel_ext.gate.MoEGate(gate_config)

    expert_ids = torch.zeros((n_tokens, num_experts_per_token), dtype=torch.int64).to("cpu").contiguous()
    expert_weights = torch.zeros((n_tokens, num_experts_per_token), dtype=torch.float32).to("cpu").contiguous()
    gate.forward(n_tokens, hidden_states.data_ptr(), expert_ids.data_ptr(), expert_weights.data_ptr())
    return expert_ids, expert_weights
# Random token activations fed to both implementations. Named ``hidden_input``
# so the builtin ``input`` is not shadowed.
hidden_input = torch.randn(seqlen, hidden_size, dtype=torch.float32).to("cpu").contiguous()

# Neither implementation is relied on to return a token's experts in a fixed
# order (the PyTorch reference uses topk(sorted=False)). So each row is sorted
# by expert id, and its weights are permuted alongside, before comparing.
ids, we = cpuinfer_gate(hidden_input)
ids, order = torch.sort(ids, dim=-1, descending=True)
we = torch.gather(we, dim=-1, index=order)

std_ids, std_we = torch_gate(hidden_input)
std_ids, order = torch.sort(std_ids, dim=-1, descending=True)
std_we = torch.gather(std_we, dim=-1, index=order)

assert torch.abs(std_ids - ids).max() == 0, "Expert IDs do not match!"
assert torch.abs(std_we - we).max() < 1e-2, "Expert Weights do not match!"
print("Expert IDs and Weights match successfully!")