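"""Correctness check for the kt_kernel CPU MoE gate.

Runs DeepSeek-V3 style "noaux_tc" sigmoid gating twice on the same random
weights and input, once through the C++ kt_kernel_ext implementation and once
through a PyTorch reference MoEGate, then asserts that the selected expert
ids match exactly and that the gate weights agree to within 1e-2.
"""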

import math
import os
import sys
import time
from typing import Optional

os.environ["BLAS_NUM_THREADS"] = "1"
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type

import torch
from torch import nn
import torch.nn.functional as F

# from modeling_deepseek_v3 import MoEGate
from configuration_deepseek_v3 import DeepseekV3Config

seed = 42  # any integer works as the seed
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

seqlen = 64

config = DeepseekV3Config()

hidden_size = config.hidden_size
num_experts_per_token = config.num_experts_per_tok
n_routed_experts = config.n_routed_experts
n_group = config.n_group
topk_group = config.topk_group
routed_scaling_factor = config.routed_scaling_factor

weights = torch.randn((n_routed_experts, hidden_size), dtype=torch.float32).to("cpu").contiguous()
bias = torch.randn((n_routed_experts,), dtype=torch.float32).to("cpu").contiguous()


# weights = torch.randn((n_routed_experts, hidden_size), dtype=torch.float16).to('cpu').contiguous()
def load_fp32_tensor(file_path, shape):
    # Debug dumps are optional; fall back to zeros when the dump file is missing.
    if not os.path.exists(file_path):
        return torch.zeros(shape, dtype=torch.float32).to("cpu").contiguous()
    with open(file_path, "rb") as f:
        raw_data = f.read()
    tensor = torch.frombuffer(raw_data, dtype=torch.float32)
    tensor = tensor.view(shape)  # reshape to the expected shape
    return tensor
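
# A dump file is just a tensor's raw fp32 bytes in row-major order, matching
# the torch.frombuffer call above. For illustration (path hypothetical), one
# could be written from Python with:
#   hidden.float().contiguous().numpy().tofile("debug/gate_input")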


class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
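
    # Note on the "noaux_tc" routing below: e_score_correction_bias only
    # influences which experts are selected; the returned gate weights are
    # gathered from the uncorrected sigmoid scores.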
    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)

        # Cross-check intermediate tensors against dumps from the C++ side.
        h_to_check = load_fp32_tensor(
            "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_input", (seq_len, h)
        )
        diff = (h_to_check - hidden_states).abs().max()
        # print("hidden_states diff:", diff)
        # assert diff < 0.02

        bias_to_check = load_fp32_tensor(
            "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/bias", (n_routed_experts)
        )
        diff = (bias - bias_to_check).abs().max()
        # print("bias diff:", diff)
        # assert diff < 0.02

        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)

        logits_to_check = load_fp32_tensor(
            "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_logits",
            (seq_len, n_routed_experts),
        )
        diff = (logits_to_check - logits).abs().max()
        # print("logits diff:", diff)
        # assert diff < 0.02

        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(f"unsupported scoring function for MoE gating: {self.scoring_func}")

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            # assert not self.training
            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)

            scores_to_check = load_fp32_tensor(
                "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/scores_to_choice",
                (seq_len, n_routed_experts),
            )
            diff = (scores_for_choice - scores_to_check).abs().max()
            print(f"score for choice diff = {diff}")

            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
            )  # [n, n_group]

            group_scores_to_check = load_fp32_tensor(
                "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/group_scores",
                (seq_len, n_group),
            )
            diff = (group_scores - group_scores_to_check).abs().max()
            print(f"group scores diff = {diff}")

            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
            tmp_scores_to_check = load_fp32_tensor(
                "/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_logits_toped",
                (seq_len, n_routed_experts),
            )
            is_close = torch.isclose(tmp_scores, tmp_scores_to_check, rtol=1e-2, atol=1e-2, equal_nan=True)
            print(f"tmp_score ok {is_close.all()}")

            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(f"unsupported TopK function for MoE gating: {self.topk_method}")

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
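
# Shape walkthrough of the group-limited top-k above, assuming stock
# DeepSeek-V3 gating sizes (n_routed_experts=256, n_group=8, topk_group=4,
# top_k=8; treat these numbers as illustrative if the local config differs):
#   scores_for_choice: [n, 256]  bias-corrected sigmoid scores
#   group_scores:      [n, 8]    sum of each group's top-2 corrected scores
#   score_mask:        [n, 256]  marks experts in the best topk_group=4 groups
#   topk_idx/weight:   [n, 8]    final experts and their (uncorrected) gates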


def torch_gate(hidden_states):
    # Reference path: run the PyTorch MoEGate with the same weights and bias.
    hidden_states.unsqueeze_(0)  # add a batch dimension, in place
    gate = MoEGate(config)
    gate.weight.data = weights
    gate.e_score_correction_bias.data = bias
    y = gate(hidden_states)
    # print(y)
    return y


def cpuinfer_gate(hidden_states):
    # CPU-kernel path: tensors are handed over as raw pointers, so they must
    # stay contiguous (and alive) for the duration of the call.
    config = kt_kernel_ext.gate.GateConfig(
        hidden_size,
        num_experts_per_token,
        n_routed_experts,
        n_group,
        topk_group,
    )

    CPUInfer = kt_kernel_ext.CPUInfer(64)
    config.routed_scaling_factor = routed_scaling_factor

    config.pool = CPUInfer.backend_
    config.weight = weights.data_ptr()
    config.weight_type = ggml_type.FP32
    config.e_score_correction_bias = bias.data_ptr()
    config.e_score_correction_bias_type = ggml_type.FP32

    gate = kt_kernel_ext.gate.MoEGate(config)

    expert_ids = torch.zeros((seqlen, num_experts_per_token), dtype=torch.int64).to("cpu").contiguous()
    expert_weights = torch.zeros((seqlen, num_experts_per_token), dtype=torch.float32).to("cpu").contiguous()

    gate.forward(seqlen, hidden_states.data_ptr(), expert_ids.data_ptr(), expert_weights.data_ptr())

    # print(expert_ids, expert_weights)
    return expert_ids, expert_weights
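

# The kernel and the reference may emit the selected experts in a different
# order (both top-k selections above use sorted=False), so sort both results
# by expert id before comparing element-wise.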
input = torch.randn(seqlen, hidden_size, dtype=torch.float32).to("cpu").contiguous()
# print(input)
ids, we = cpuinfer_gate(input)
idx = torch.argsort(ids, dim=-1, descending=True)
ids = torch.gather(ids, dim=-1, index=idx)
we = torch.gather(we, dim=-1, index=idx)

std_ids, std_we = torch_gate(input)
idx = torch.argsort(std_ids, dim=-1, descending=True)
std_we = torch.gather(std_we, dim=-1, index=idx)
std_ids = torch.gather(std_ids, dim=-1, index=idx)

# print("ids diff:", torch.abs(std_ids - ids).max())
# print("weights diff:", torch.abs(std_we - we).max())
assert torch.abs(std_ids - ids).max() == 0, "Expert IDs do not match!"
assert torch.abs(std_we - we).max() < 1e-2, "Expert Weights do not match!"
print("Expert IDs and Weights match successfully!")
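

# Optional micro-benchmark sketch, not part of the correctness check; `bench`
# and the iteration count are illustrative, and the timings are machine
# dependent. torch_gate unsqueezes its argument in place, so each call gets a
# fresh, flat clone.
def bench(fn, x, iters=10):
    start = time.perf_counter()
    for _ in range(iters):
        fn(x.reshape(seqlen, hidden_size).clone())
    return (time.perf_counter() - start) / iters


# print(f"cpuinfer_gate: {bench(cpuinfer_gate, input) * 1e3:.3f} ms/iter")
# print(f"torch_gate:    {bench(torch_gate, input) * 1e3:.3f} ms/iter")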