# Mirror of https://github.com/kvcache-ai/ktransformers.git
# (synced 2026-03-15 02:47:22 +00:00) - 337 lines, 14 KiB, Python.
import os
import sys
import time
from typing import Optional, Union

# kt_kernel is built into ../build; extend sys.path before importing it.
sys.path.insert(0, os.path.dirname(__file__) + "/../build")

from kt_kernel import kt_kernel_ext
from kt_kernel_ext.kvcache import ggml_type

import torch
from torch import nn
from torch.nn import init

from torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding

# Fixed seed so the random weights and inputs are reproducible (any integer works).
seed = 42
torch.manual_seed(seed)  # CPU generator
torch.cuda.manual_seed_all(seed)  # every CUDA device (silently ignored without CUDA)

# Shape of the single test batch: qlen new tokens following kvlen cached tokens.
qlen, kvlen = 1024, 0

# Paging: the page list handed to the CPU kernel, the batch count, page geometry.
page_table = range(20)
bsz_tensors = torch.tensor([1])
page_size, pages_count = 256, 200
tp_count = 4  # not referenced elsewhere in the visible code

# DeepSeek-V3 MLA dimensions.
hidden_size = 7168
q_lora_rank, kv_lora_rank = 1536, 512
num_heads = 128
nope_size, rope_size = 128, 64

# RoPE base and the maximum sizes the kernel is configured for.
rope_theta = 10000
max_qlen, max_kvlen = 1024, 4096
max_position_embeddings = 163840

# YaRN RoPE scaling parameters (shared by the CPU config and the torch reference).
rope_scaling = dict(
    beta_fast=32,
    beta_slow=1,
    factor=40,
    mscale=1.0,
    mscale_all_dim=1.0,
    original_max_position_embeddings=4096,
    type="yarn",
)

# Number of validation iterations (not referenced elsewhere in the visible code).
validation_iter = 100
# CPU inference runtime; its backend_ is used as the MLA config's pool.
# NOTE(review): 64 is presumably the worker-thread count - confirm in kt_kernel_ext.
CPUInfer = kt_kernel_ext.CPUInfer(64)

# fp16 projections of one MLA layer, shared by the CPU kernel (through the
# *_weight tensors below) and by the torch reference.
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)
q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=torch.float16)
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (nope_size + nope_size), bias=False, dtype=torch.float16)
o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=torch.float16)

# Redraw every weight from N(0, 0.02^2), in the same order the layers were built
# (keeps the RNG stream, and therefore the weights, reproducible).
for _proj in (q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj, o_proj):
    init.normal_(_proj.weight, mean=0.0, std=0.02)
del _proj

# fp16 / CPU / contiguous versions of the weights, whose data_ptr() is handed to
# the kernel; .to()/.contiguous() return the same tensor when already satisfied.
q_a_proj_weight = q_a_proj.weight.to(torch.float16).to("cpu").contiguous()
q_b_proj_weight = q_b_proj.weight.to(torch.float16).to("cpu").contiguous()
kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to(torch.float16).to("cpu").contiguous()
kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to("cpu").contiguous()
o_proj_weight = o_proj.weight.to(torch.float16).to("cpu").contiguous()

# Describe the MLA layer to the CPU kernel: model dimensions, size limits,
# RoPE/YaRN parameters (mirroring rope_scaling) and the fp16 projection weights.
config = kt_kernel_ext.mla.MLAConfig(hidden_size, q_lora_rank, kv_lora_rank, num_heads, nope_size, rope_size)

# Sequence-length / position limits.
config.max_qlen = max_qlen
config.max_kvlen = max_kvlen
config.max_position_embeddings = max_position_embeddings

# RoPE base and YaRN scaling.
config.rope_theta = rope_theta
config.rope_scaling_factor = rope_scaling["factor"]
config.rope_scaling_beta_fast = rope_scaling["beta_fast"]
config.rope_scaling_beta_slow = rope_scaling["beta_slow"]
config.rope_scaling_mscale = rope_scaling["mscale"]
config.rope_scaling_mscale_all_dim = rope_scaling["mscale_all_dim"]
config.rope_scaling_original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]

# Weights are passed as raw addresses plus their ggml type.
# NOTE(review): presumably consumed by mla.load_weights(); keep the *_weight
# tensors alive at least until then - confirm against the kernel.
config.q_a_proj = q_a_proj_weight.data_ptr()
config.q_a_proj_type = ggml_type.FP16
config.q_b_proj = q_b_proj_weight.data_ptr()
config.q_b_proj_type = ggml_type.FP16
config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()
config.kv_a_proj_with_mqa_type = ggml_type.FP16
config.kv_b_proj = kv_b_proj_weight.data_ptr()
config.kv_b_proj_type = ggml_type.FP16
config.o_proj = o_proj_weight.data_ptr()
config.w_o_type = ggml_type.FP16

# Run on the CPUInfer backend.
config.pool = CPUInfer.backend_

# Build the CPU MLA operator, load its weights and allocate its local KV pages.
mla = kt_kernel_ext.mla.MLA(config)
mla.load_weights()
mla.set_local_pages(pages_count)

# Random fp16 input and zeroed output; the kernel accesses both through raw
# data pointers, hence CPU-resident and contiguous.
input = torch.randn((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()
output = torch.zeros((qlen, hidden_size), dtype=torch.float16).to("cpu").contiguous()

# One batch: qlen new tokens, pages from page_table, kvlen tokens already cached.
mla.forward([qlen], [page_table], [kvlen], input.data_ptr(), output.data_ptr())
print("CPU MLA Output: ", output)

# ---- State used by the PyTorch reference (torch_attn) ----

softmax_scale = (nope_size + rope_size) ** -0.5
# NOTE(review): DeepSeek-V3's YaRN attention additionally multiplies this by
# mscale**2; confirm whether the CPU kernel applies that before comparing.

# Paged cache of shape [1, pages_count, page_size, 1, kv_lora_rank + rope_size];
# the size-1 dim before the feature dim is the number of compressed-KV heads.
k_caches = torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(torch.float16)
kv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)

q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)

# Sanity check of the RMSNorm on a large-magnitude fp16 vector.
x = torch.randn(q_lora_rank, dtype=torch.float16) * 100
print(x)
print(q_a_layernorm(x))

kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)

# Absorbed up-projections taken from kv_b_proj - the same weights the CPU MLA
# receives - so both paths compute the same function. (Random tensors here made
# the comparison meaningless.) kv_b_proj.weight is
# [num_heads * (nope_size + nope_size), kv_lora_rank]; viewed per head.
# NOTE(review): assumes the DeepSeek layout where each head's first nope_size
# rows produce k_nope and the rest produce v - confirm against the CPU kernel.
_kv_b_proj_heads = kv_b_proj.weight.detach().view(num_heads, nope_size + nope_size, kv_lora_rank)
q_absorb = _kv_b_proj_heads[:, :nope_size, :]  # [num_heads, nope_size, kv_lora_rank]
out_absorb = _kv_b_proj_heads[:, nope_size:, :]  # [num_heads, v_head_dim(=nope_size), kv_lora_rank]

rotary_emb = DeepseekV3YarnRotaryEmbedding(
    rope_size,
    max_position_embeddings=max_position_embeddings,
    scaling_factor=rope_scaling["factor"],
    base=rope_theta,
    beta_fast=rope_scaling["beta_fast"],
    beta_slow=rope_scaling["beta_slow"],
    mscale=rope_scaling["mscale"],
    mscale_all_dim=rope_scaling["mscale_all_dim"],
    original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)

# Ragged metadata for one batch (bsz = 1) of qlen new tokens placed after kvlen
# cached tokens. hidden_states is a random qlen-row input (torch_attn is called
# with `input` instead; this tensor is only used for its device below).
hidden_states = torch.randn(qlen, hidden_size, dtype=torch.float16)
q_indptr = torch.tensor([0, qlen], dtype=torch.int32)

# The batch uses ceil((qlen + kvlen) / page_size) pages, taken from the start of kv_indices.
kv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size], dtype=torch.int32)
kv_indices = torch.arange(pages_count, dtype=torch.int32)

# Positions kvlen .. kvlen + qlen - 1 and the (page, slot) each new token is written to.
position_ids = torch.arange(kvlen, kvlen + qlen, dtype=torch.int32)
page_idx = position_ids // page_size
page_offset = position_ids % page_size

# Valid slots in the batch's last page. A full last page holds page_size tokens;
# the previous `(qlen + kvlen) % page_size` gave 0 there (e.g. qlen=1024), which
# dropped a whole page from the reference's kv length.
last_page_len = torch.tensor([(qlen + kvlen - 1) % page_size + 1], device=hidden_states.device)

# Additive causal mask of shape [qlen, kvlen + qlen]: query row i (absolute
# position kvlen + i) may attend keys 0 .. kvlen + i; every later key gets
# -65504.0, the most negative finite float16.
attention_masks = torch.triu(
    torch.full((qlen, kvlen + qlen), -65504.0, dtype=torch.float16),
    diagonal=kvlen + 1,
)


def torch_attn(
    hidden_states: torch.Tensor,
    kv_cache: KDeepSeekV3Cache,
    position_ids: torch.Tensor,
    page_idx: torch.Tensor,
    page_offset: torch.Tensor,
    attention_masks: Optional[Union[list[torch.Tensor], torch.Tensor]] = None,
    q_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_indptr: Optional[torch.Tensor] = None,
    bsz_tensors: Optional[torch.Tensor] = None,
    last_page_len: Optional[torch.Tensor] = None,
    layer_idx: Optional[int] = None,
):
    """PyTorch reference for absorbed MLA attention over a paged KV cache.

    Uses the module-level projections (q_a_proj, q_b_proj, kv_a_proj_with_mqa,
    o_proj), layernorms, rotary_emb, softmax_scale and the absorbed matrices
    q_absorb / out_absorb. For every batch it passes the new compressed KV and
    k_pe to ``kv_cache.update`` (presumably writing them at page_idx /
    page_offset) and attends over the valid prefix of the returned pages.

    Args:
        hidden_states: [total_q_tokens, hidden_size]; batch i is rows
            q_indptr[i]:q_indptr[i + 1].
        kv_cache: paged cache, ``update`` is called once per batch.
        position_ids: positions of the tokens (same row layout as hidden_states).
        page_idx, page_offset: per-token destination page / slot.
        attention_masks: additive masks, one [q_len, k_len] mask per batch. Either
            a list indexed by batch or a single 2-D tensor (used for the only
            batch). None disables masking.
        q_indptr: batch boundaries into hidden_states (length bsz + 1).
        kv_indices, kv_indptr: pages of batch i are
            kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
        bsz_tensors: bsz_tensors[0] is the number of batches.
        last_page_len: valid slots in each batch's last page; None means every
            page is full.
        layer_idx: forwarded to kv_cache.update.

    Returns:
        [total_q_tokens, hidden_size] attention output after o_proj, in the
        activation dtype.
    """
    # A single 2-D mask belongs to the only batch; indexing it by batch would
    # instead select its first row and apply that row to every query.
    if isinstance(attention_masks, torch.Tensor) and attention_masks.dim() == 2:
        attention_masks = [attention_masks]

    # out_absorb is [num_heads, v_head_dim, kv_lora_rank]. Transpose into a local
    # rather than rebinding the global, so every batch and call sees the same tensor.
    out_absorb_t = out_absorb.transpose(1, 2)  # [num_heads, kv_lora_rank, v_head_dim]

    batch_outputs = []
    for i in range(bsz_tensors[0]):
        q_start, q_end = q_indptr[i], q_indptr[i + 1]
        batch_last_page_len = last_page_len[i] if last_page_len is not None else None
        batch_page_idx = page_idx[q_start:q_end]
        batch_page_offset = page_offset[q_start:q_end]

        # Valid kv length: all of this batch's pages minus the unused tail of the last one.
        kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]
        kv_total_len = kv_page_nums * page_size
        if batch_last_page_len is not None:
            kv_total_len = kv_total_len - (page_size - batch_last_page_len)
        kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]

        batch_hidden_states = hidden_states[q_start:q_end]
        batch_position_ids = position_ids[q_start:q_end]
        batch_qlen, _ = batch_hidden_states.size()

        q_lora = q_a_proj(batch_hidden_states)
        print("q_a_proj", q_a_proj.weight)
        print("q_lora", q_lora)
        q = q_b_proj(q_a_layernorm(q_lora))
        print("q_b_proj", q_b_proj.weight)
        # [q, num_heads, nope + rope] -> q_nope [q, H, nope], q_pe [q, H, rope]
        q = q.view(batch_qlen, num_heads, nope_size + rope_size)
        q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)
        print("q_nope", q_nope)
        print("q_pe", q_pe)

        # [q, kv_lora_rank + rope] -> compressed_kv [q, kv_lora_rank], k_pe [q, rope]
        compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)
        compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)
        compressed_kv = kv_a_layernorm(compressed_kv.contiguous())
        print("compressed_kv ", compressed_kv)
        print("k_pe ", k_pe)
        k_pe = k_pe.view(batch_qlen, 1, rope_size)
        compressed_kv = compressed_kv.view(batch_qlen, 1, kv_lora_rank)

        cos, sin = rotary_emb(q_pe, batch_position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)
        q_pe = q_pe.squeeze(0).transpose(0, 1)  # [H, q, rope]

        if kv_cache is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "page_idx": batch_page_idx,
                "page_offset": batch_page_offset,
            }  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(
                compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs
            )
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)

        # Absorb the key up-projection into the query:
        # [H, q, nope] @ [H, nope, kv_lora_rank] -> [H, q, kv_lora_rank]
        q_nope = torch.matmul(q_nope.transpose(0, 1), q_absorb)

        # Gather the valid prefix of this batch's pages: [k_len, kv_lora_rank] and [k_len, rope].
        kv_chunks, pe_chunks = [], []
        remaining = int(kv_total_len)
        for page_index in kv_index:
            take = min(remaining, page_size)
            kv_chunks.append(compressed_kv[page_index, 0:take, :])
            pe_chunks.append(k_pe[page_index, 0:take, :])
            remaining -= take
            if remaining <= 0:
                break
        batch_compressed_kv = torch.cat(kv_chunks, dim=0)
        batch_k_pe = torch.cat(pe_chunks, dim=0)

        # Scores = rope part + absorbed nope part, each [H, q, k_len].
        pe_weights = torch.matmul(q_pe, batch_k_pe.mT)
        print("pe_weights", pe_weights)
        attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)) * softmax_scale
        if attention_masks is not None:
            # mask [q, k_len] broadcasts over the head dim
            attention_weights = attention_weights + attention_masks[i]
        # Softmax in float32 for accuracy, then back to the activation dtype.
        attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(q_pe.dtype)

        attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [H, q, kv_lora_rank]
        attn_output = torch.matmul(attn_output, out_absorb_t)  # [H, q, v_head_dim]
        attn_output = attn_output.transpose(0, 1).reshape(batch_qlen, num_heads * nope_size)
        batch_outputs.append(o_proj(attn_output))

    # Concatenate once at the end (an empty float32 seed tensor would promote the dtype).
    if not batch_outputs:
        return hidden_states.new_empty((0, hidden_size))
    return torch.cat(batch_outputs, dim=0)


# Run the PyTorch reference on the same `input` that was fed to the CPU MLA.
# no_grad: the nn.Linear weights require grad, but only values are needed here.
with torch.no_grad():
    torch_output = torch_attn(
        input,
        kv_cache,
        position_ids,
        page_idx,
        page_offset,
        # torch_attn takes one mask per batch; wrap the single [qlen, kvlen + qlen]
        # mask so batch 0 gets the whole mask rather than its first row.
        attention_masks=[attention_masks],
        q_indptr=q_indptr,
        kv_indices=kv_indices,
        kv_indptr=kv_indptr,
        bsz_tensors=bsz_tensors,
        last_page_len=last_page_len,
        layer_idx=0,
    )
print("Torch Output: ", torch_output)

# Mean absolute difference relative to the reference's mean magnitude.
relative_diff = (output.float() - torch_output.float()).abs().mean() / torch_output.float().abs().mean()
print("Relative difference (CPU vs Torch): ", relative_diff.item())