Files
ktransformers/kt-kernel/examples/test_mla_torch.py
2025-12-17 19:46:32 +08:00

337 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os, sys
import time
from typing import Optional
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type
import torch
from torch import nn
from torch.nn import init
from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding
seed = 42 # any integer works as the seed; fixed so CPU-kernel vs torch runs are reproducible
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# --- Test-case geometry ---
qlen = 1024  # number of new query tokens in this (prefill) forward pass
kvlen = 0  # tokens already present in the KV cache (0 = pure prefill)
page_table = range(20)  # page ids handed to the CPU MLA kernel
bsz_tensors = torch.tensor([1])  # batch size (single request)
page_size = 256  # tokens per KV-cache page
pages_count = 200  # total pages allocated for the paged cache
tp_count = 4  # NOTE(review): unused in this script — confirm before relying on it
# --- DeepSeek-V3 MLA dimensions ---
hidden_size = 7168
q_lora_rank = 1536  # low-rank dim of the query down-projection (q_a_proj output)
kv_lora_rank = 512  # compressed ("latent") KV dim shared by all heads
num_heads = 128
nope_size = 128  # per-head dim that does NOT carry rotary embedding
rope_size = 64  # per-head dim that carries rotary embedding
rope_theta = 10000
max_qlen = 1024
max_kvlen = 4096
max_position_embeddings = 163840
# YaRN rope-scaling parameters, mirrored into MLAConfig below
rope_scaling = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
}
CPUInfer = kt_kernel_ext.CPUInfer(64)  # CPU inference backend with 64 threads
validation_iter = 100  # NOTE(review): unused in this script — confirm before relying on it
# Random FP16 projection layers matching the DeepSeek-V3 MLA layout.
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=torch.float16)
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (nope_size + nope_size), bias=False, dtype=torch.float16)
o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=torch.float16)
init.normal_(q_a_proj.weight, mean=0.0, std=0.02)
init.normal_(q_b_proj.weight, mean=0.0, std=0.02)
init.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)
init.normal_(kv_b_proj.weight, mean=0.0, std=0.02)
init.normal_(o_proj.weight, mean=0.0, std=0.02)
# Contiguous FP16 CPU copies of the weights; the C++ kernel reads them through
# raw data_ptr()s, so these tensors must stay alive for the lifetime of `mla`.
q_a_proj_weight = q_a_proj.weight.to(torch.float16).to("cpu").contiguous()
q_b_proj_weight = q_b_proj.weight.to(torch.float16).to("cpu").contiguous()
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to("cpu").to(torch.float16).contiguous()
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to("cpu").contiguous()
o_proj_weight = o_proj.weight.to(torch.float16).to("cpu").contiguous()
# Build the kernel-side MLA configuration (dims, rope params, weight pointers).
config = kt_kernel_ext.mla.MLAConfig(
hidden_size,
q_lora_rank,
kv_lora_rank,
num_heads,
nope_size,
rope_size,
)
config.max_qlen = max_qlen
config.max_kvlen = max_kvlen
config.max_position_embeddings = max_position_embeddings
config.rope_scaling_factor = rope_scaling["factor"]
config.rope_theta = rope_theta
config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
config.rope_scaling_beta_slow = rope_scaling["beta_slow"]
config.rope_scaling_mscale = rope_scaling["mscale"]
config.rope_scaling_mscale_all_dim = rope_scaling["mscale_all_dim"]
config.rope_scaling_original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
# Raw weight pointers: the kernel does NOT copy, it reads these buffers in place.
config.q_a_proj = q_a_proj_weight.data_ptr()
config.q_b_proj = q_b_proj_weight.data_ptr()
config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()
config.kv_b_proj = kv_b_proj_weight.data_ptr()
config.o_proj = o_proj_weight.data_ptr()
config.q_a_proj_type = ggml_type.FP16
config.q_b_proj_type = ggml_type.FP16
config.kv_a_proj_with_mqa_type = ggml_type.FP16
config.kv_b_proj_type = ggml_type.FP16
config.w_o_type = ggml_type.FP16
config.pool = CPUInfer.backend_
mla = kt_kernel_ext.mla.MLA(config)
mla.load_weights()
mla.set_local_pages(pages_count)
# NOTE(review): `input` shadows the builtin; kept because the torch reference
# below is called with the same tensor for an apples-to-apples comparison.
input = torch.randn((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()
output = torch.zeros((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()
# Run the CPU kernel: results are written into `output` through its data_ptr().
mla.forward([qlen], [page_table], [kvlen], input.data_ptr(), output.data_ptr())
print("CPU MLA Output: ", output)
softmax_scale = (nope_size + rope_size) ** -0.5
# 1代表的是压缩的kv的头数
k_caches = torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(torch.float16)
kv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)
q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)
x = torch.randn(q_lora_rank, dtype=torch.float16) * 100
print(x)
print(q_a_layernorm(x))
kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)
q_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)
out_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)
rotary_emb = DeepseekV3YarnRotaryEmbedding(
rope_size,
max_position_embeddings=max_position_embeddings,
scaling_factor=rope_scaling["factor"],
base=rope_theta,
beta_fast=rope_scaling["beta_fast"],
beta_slow=rope_scaling["beta_slow"],
mscale=rope_scaling["mscale"],
mscale_all_dim=rope_scaling["mscale_all_dim"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
# 构造一个qlen 长度的输入 hidden_states, 对应的历史 kv_indptr 是[0:bsz]
# kv_indices 是[0:bsz]page_idx=[0:bsz], page_offset=[kvlen:qlen+kvlen]
# last_page_len = [qlen+kvlen,...] layer_idx = 1
# position_ids = [kvlen:qlen+kvlen]
hidden_states = torch.randn(qlen, hidden_size, dtype=torch.float16)
q_indptr = torch.tensor([0, qlen]).to(torch.int32)
kv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size]).to(torch.int32)
kv_indices = torch.tensor(range(pages_count)).to(torch.int32)
page_idx = torch.tensor([i // page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
page_offset = torch.tensor([i % page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)
last_page_len = torch.tensor([(qlen + kvlen) % page_size], device=hidden_states.device)
position_ids = torch.tensor(range(kvlen, kvlen + qlen)).to(torch.int32)
# 按照行创建 mask [qlen,kvlen+qlen]
attention_masks = torch.zeros((qlen, kvlen + qlen), dtype=torch.float16)
for i in range(qlen):
attention_masks[i, i + kvlen + 1 : i + kvlen + qlen] = -65504.0
def torch_attn(
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
attention_masks: Optional[list[torch.Tensor]] = None,
q_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_indptr: Optional[torch.Tensor] = None,
bsz_tensors: Optional[torch.Tensor] = None,
last_page_len: Optional[torch.Tensor] = None,
layer_idx: Optional[int] = None,
):
    """Pure-torch reference for paged MLA (multi-head latent attention).

    For each batch entry it: projects hidden states to low-rank queries and
    compressed KV, applies YaRN RoPE, writes the compressed KV into the paged
    cache, gathers that entry's pages back into contiguous tensors, and runs
    "absorbed" attention (q_absorb / out_absorb) followed by o_proj. Relies on
    the module-level layers/tensors: q_a_proj, q_b_proj, kv_a_proj_with_mqa,
    o_proj, q_a_layernorm, kv_a_layernorm, rotary_emb, q_absorb, out_absorb,
    softmax_scale and the dimension constants.

    Args:
        hidden_states: [total_qlen, hidden_size] fp16 input tokens.
        kv_cache: paged compressed-KV cache; updated in place when not None.
        position_ids: absolute position of every query token (for RoPE).
        page_idx / page_offset: per-token cache page id / in-page slot.
        attention_masks: one [qlen_i, kv_len_i] additive mask per batch entry.
        q_indptr: CSR offsets bounding each batch's query tokens.
        kv_indices / kv_indptr: CSR page ids bounding each batch's cache pages.
        bsz_tensors: one-element tensor holding the batch count.
        last_page_len: valid token count in each batch's last page.
        layer_idx: cache layer to update.

    Returns:
        [total_qlen, hidden_size] attention output, concatenated over batches.
    """
    final_attention_output = torch.tensor([], device=hidden_states.device)
    for i in range(bsz_tensors[0]):
        batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]
        batch_last_page_len = last_page_len[i]
        batch_page_idx = page_idx[q_indptr[i] : q_indptr[i + 1]]
        batch_page_offset = page_offset[q_indptr[i] : q_indptr[i + 1]]
        # Number of cache pages belonging to this batch entry.
        kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
        # Total kv length: every page full except the (possibly partial) last one.
        kv_total_len = kv_page_nums * page_size
        if batch_last_page_len is not None:
            kv_total_len = kv_total_len - (page_size - batch_last_page_len)
        # Page ids of this batch's cache pages; indexing [kv_index, offset]
        # recovers this batch's slice of the paged cache.
        kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]
        # q_indptr[i]..q_indptr[i+1] is this batch's token range.
        batch_hidden_states = hidden_states[q_indptr[i] : q_indptr[i + 1]]
        batch_position_ids = position_ids[q_indptr[i] : q_indptr[i + 1]]
        # Deliberately shadows the module-level qlen with this batch's length.
        qlen, _ = batch_hidden_states.size()
        # Query path: down-project -> RMSNorm -> up-project.
        q_lora = q_a_proj(batch_hidden_states)
        print("q_a_proj", q_a_proj.weight)
        print("q_lora", q_lora)
        q = q_b_proj(q_a_layernorm(q_lora))
        print("q_b_proj", q_b_proj.weight)
        # [qlen, num_heads(128), qk_head_dim(192 = 128 nope + 64 rope)]
        q = q.view(qlen, num_heads, nope_size + rope_size)
        # q_nope: [qlen, num_heads, 128]; q_pe: [qlen, num_heads, 64]
        q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)
        print("q_nope", q_nope)
        print("q_pe", q_pe)
        # KV path (single shared latent head): [qlen, kv_lora_rank(512) + rope(64)]
        compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)
        # compressed_kv: [qlen, 512]; k_pe: [qlen, 64]
        compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = kv_a_layernorm(compressed_kv)
        print("compressed_kv ", compressed_kv)
        print("k_pe ", k_pe)
        # Insert the single-head axis: k_pe [qlen, 1, 64], compressed_kv [qlen, 1, 512]
        k_pe = k_pe.view(qlen, 1, rope_size)
        compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)
        # YaRN rotary embedding on the rope portions of q and k.
        cos, sin = rotary_emb(q_pe, batch_position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)
        q_pe = q_pe.squeeze(0)
        # q_pe -> [num_heads(128), qlen, rope(64)]
        q_pe.transpose_(0, 1)
        if kv_cache is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "page_idx": batch_page_idx,
                "page_offset": batch_page_offset,
            }  # Specific to RoPE models
            # Write this batch's new compressed KV into the paged cache, then
            # re-read the WHOLE cache as [pages, page_size, ...] views.
            compressed_kv_with_k_pe = kv_cache.update(
                compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs
            )
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)
        # Absorbed attention: fold kv_b_proj's nope factor into the query.
        # q_absorb: [num_heads, nope(128), kv_lora_rank(512)]
        q_nope = q_nope.transpose(0, 1)  # [num_heads, qlen, nope]; cheap view
        q_nope = torch.matmul(q_nope, q_absorb)  # [num_heads, qlen, kv_lora_rank]
        # Gather this batch's pages into contiguous [kv_total_len, ...] tensors,
        # taking a full page at a time and the remainder from the last page.
        batch_compressed_kv = None
        batch_k_pe = None
        for page_index in kv_index:
            if kv_total_len > page_size:
                tmp_compressed_kv = compressed_kv[page_index, 0:page_size, :]
                tmp_k_pe = k_pe[page_index, 0:page_size, :]
                if batch_compressed_kv is None or batch_k_pe is None:
                    batch_compressed_kv = tmp_compressed_kv
                    batch_k_pe = tmp_k_pe
                else:
                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                kv_total_len -= page_size
            else:
                tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]
                tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]
                if batch_compressed_kv is None or batch_k_pe is None:
                    batch_compressed_kv = tmp_compressed_kv
                    batch_k_pe = tmp_k_pe
                else:
                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)
                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)
                break
        # batch_compressed_kv: [kv_total_len, 512]; batch_k_pe: [kv_total_len, 64]
        pe_weights = torch.matmul(q_pe, batch_k_pe.mT)
        print("pe_weights", pe_weights)
        # attention_weights: [num_heads(128), qlen, k_len]
        attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)) * softmax_scale
        # attention_masks[i]: [qlen, k_len] additive causal mask for this batch.
        attention_weights = attention_weights + attention_masks[i]
        attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float16).to(q_pe.dtype)
        attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [num_heads, qlen, 512]
        # FIX: the original declared `global out_absorb` and rebound it with
        # `out_absorb = out_absorb.transpose(1, 2)` on every iteration, toggling
        # the module-level tensor's layout across batch entries and across calls
        # (wrong for bsz > 1 or a second invocation). Use a local view instead.
        out_absorb_t = out_absorb.transpose(1, 2)  # [num_heads, kv_lora_rank, nope]
        attn_output = torch.matmul(attn_output, out_absorb_t)  # [num_heads, qlen, nope(128)]
        attn_output = attn_output.transpose(0, 1)  # [qlen, num_heads, nope]
        attn_output = attn_output.reshape(qlen, num_heads * nope_size)
        attn_output = o_proj(attn_output)
        final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
    return final_attention_output
# Run the pure-torch reference on the SAME input as the CPU kernel so the two
# outputs can be compared.
# FIX: torch_attn expects `attention_masks` to be a list with one
# [qlen, kvlen+qlen] mask per batch entry (it indexes `attention_masks[i]`).
# The original passed the raw 2-D tensor, so `attention_masks[0]` selected a
# single ROW of the mask, which then broadcast the row-0 mask over every query
# position — masking all queries down to key 0. Wrap the mask in a list.
torch_output = torch_attn(
input,
kv_cache,
position_ids,
page_idx,
page_offset,
attention_masks=[attention_masks],
q_indptr=q_indptr,
kv_indices=kv_indices,
kv_indptr=kv_indptr,
bsz_tensors=bsz_tensors,
last_page_len=last_page_len,
layer_idx=0,
)
print("Torch Output: ", torch_output)