update: add attention and ln ut for npu (#1698)

This commit is contained in:
Shaoxu Cheng
2025-12-10 16:12:26 +08:00
committed by GitHub
parent f992de55da
commit 8995378a91
3 changed files with 795 additions and 0 deletions

View File

@@ -0,0 +1,509 @@
import sys
import types
import torch
import torch.nn as nn
import pytest
torch_npu = pytest.importorskip("torch_npu")
from ktransformers.operators.ascend.ascend_attention import (
KDeepseekV2AttentionW8A8A2Serve,
)
import ktransformers.operators.ascend.ascend_attention as attn_mod
class DummyConfig:
    """Minimal stand-in for a model config with just the fields the tests read."""

    def __init__(self, hidden_size=4, num_attention_heads=1):
        # Only these two attributes are consulted by the attention module.
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
class DummyOrigAttn(nn.Module):
    """Bare-bones replacement for the original attention module being wrapped."""

    def __init__(self, config=None, layer_idx=0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if config is None:
            dim = 4
        else:
            dim = config.hidden_size
        # Projections the serve wrapper either overwrites or leaves unused.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.kv_a_proj_with_mqa = None
        self.kv_a_layernorm = nn.LayerNorm(2)
        self.o_proj = None
class DummyDynamicQuantOps:
    """Identity stand-in for the dynamic-quantization op: echoes the first input."""

    def execute(self, inputs):
        return [inputs[0]]
class DummyMatMulOps:
    """Identity stand-in for the matmul-dequant op: echoes the first input."""

    def execute(self, inputs):
        return [inputs[0]]
class DummyQuantProj(nn.Module):
    """Fake W8A8 quantized projection exposing the tensors the kernel path reads."""

    def __init__(self, dim):
        super().__init__()
        half = torch.float16
        # Scale/offset/deq values chosen so the fake quantization is a no-op.
        self.input_scale = torch.tensor(1.0, dtype=half)
        self.input_offset = torch.tensor(0.0, dtype=half)
        self.weight = nn.Parameter(torch.zeros(dim, dim, dtype=half))
        self.quant_bias = torch.zeros(dim, dtype=half)
        self.deq_scale = torch.tensor(1.0, dtype=half)
class DummyStaticCache:
    """Paged-KV-cache stub: records a page size and performs no real caching."""

    def __init__(self, page_size=16):
        self.page_size = page_size

    def get_usable_length(self, kv_seq_len, layer_idx):
        # Nothing is ever cached, so the usable prefix length is always zero.
        return 0

    def update(self, combined, layer_idx, cache_kwargs):
        # Pass the new states straight through; there is no backing store.
        return combined, None
class DummyNpuFusedAttention:
    """Callable mimicking torch_npu.npu_fused_infer_attention_score.

    The plain call returns zero-filled outputs matching the query layout;
    the ``out`` variant zeroes the caller-provided buffers in place.
    """

    def __call__(self, q, k, v, **kwargs):
        # Same shape/dtype/device as q, filled with zeros.
        attn_output = torch.zeros_like(q)
        softmax_lse = q.new_zeros(1)
        return attn_output, softmax_lse

    def out(self, q, k, v, workspace=None,
            query_rope=None, key_rope=None,
            num_heads=None, num_key_value_heads=None,
            input_layout=None, scale=None,
            antiquant_mode=None, antiquant_scale=None,
            block_table=None, block_size=None,
            actual_seq_lengths_kv=None,
            sparse_mode=None,
            out=None):
        # ``out`` carries the pre-allocated (attn_output, softmax_lse) pair;
        # emulate the kernel by clearing both buffers in place.
        result, lse = out
        result.zero_()
        lse.zero_()
        return result, lse
class DummyOpsNpu:
    """Replacement for ``torch.ops.npu`` exposing the fused-attention entry point."""

    def npu_fused_infer_attention_score(self, q, k, v, **kwargs):
        # Zero tensors matching the query's shape, dtype and device.
        return torch.zeros_like(q), q.new_zeros(1)
def fake_apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin):
    """No-op rotary embedding: hand back the query/key position tensors untouched."""
    return (q_pe, k_pe)
def build_attention_module(q_lora_rank=None):
    """Construct a KDeepseekV2AttentionW8A8A2Serve wired up with dummy collaborators.

    Args:
        q_lora_rank: None selects the plain q_proj path inside forward();
            any non-None value selects the q_a_proj/q_b_proj low-rank branch.

    Returns:
        The configured attention module, usable with CPU tensors (the NPU
        kernels are replaced by dummies via the autouse fixture).
    """
    # NOTE(review): direct assignment, not monkeypatch — relies on the autouse
    # _patch_env fixture's teardown to restore this module attribute.
    if hasattr(attn_mod, "get_tensor_parallel_size"):
        attn_mod.get_tensor_parallel_size = lambda: 1  # type: ignore
    config = DummyConfig(hidden_size=4, num_attention_heads=1)
    orig = DummyOrigAttn(config=config, layer_idx=0)
    attn = KDeepseekV2AttentionW8A8A2Serve(
        key="test",
        gguf_loader=None,
        config=config,
        orig_module=orig,
        prefill_device="npu",
        generate_device="npu",
    )
    # Tiny head geometry so the shape bookkeeping stays easy to follow.
    hidden_dim = 4
    num_heads = 1
    qk_nope_head_dim = 2
    qk_rope_head_dim = 2
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 4
    kv_lora_rank = 2
    v_head_dim = 2
    attn.num_heads = num_heads
    attn.q_head_dim = q_head_dim
    attn.qk_nope_head_dim = qk_nope_head_dim
    attn.qk_rope_head_dim = qk_rope_head_dim
    attn.kv_lora_rank = kv_lora_rank
    attn.v_head_dim = v_head_dim
    attn.softmax_scale = 1.0
    attn.layer_idx = 0
    attn.sparse_mode = 0
    attn.q_lora_rank = q_lora_rank
    # Identity-like quant/matmul ops: forward() only exercises the call sites.
    attn.elewise_quant = DummyDynamicQuantOps()
    attn.matmulDequant_operation = DummyMatMulOps()
    attn.matmulDequant_operation_aclnn = DummyMatMulOps()
    orig_mod = attn.orig_module
    if q_lora_rank is None:
        # Plain query projection path.
        orig_mod.q_proj = nn.Linear(hidden_dim, num_heads * q_head_dim, bias=False)
        orig_mod.q_proj = orig_mod.q_proj.to(dtype=torch.float16)
    else:
        # Low-rank (LoRA-style) query projection path.
        orig_mod.q_a_proj = DummyQuantProj(hidden_dim)
        orig_mod.q_b_proj = DummyQuantProj(hidden_dim)
        orig_mod.q_a_layernorm = nn.LayerNorm(hidden_dim)
    orig_mod.kv_a_proj_with_mqa = DummyQuantProj(hidden_dim)
    orig_mod.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)
    orig_mod.o_proj = DummyQuantProj(num_heads * v_head_dim)
    attn.q_absorb = torch.randn(
        num_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.float16
    )
    attn.out_absorb = torch.randn(
        num_heads, kv_lora_rank, v_head_dim, dtype=torch.float16
    )
    def fake_rotary_emb(q_pe, position_ids):
        # cos=1 / sin=0 makes the rotary embedding an identity transform.
        bsz, n_heads, q_len, dim = q_pe.shape
        cos = torch.ones(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device)
        sin = torch.zeros(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device)
        return cos, sin
    attn.rotary_emb = fake_rotary_emb
    return attn
@pytest.fixture(autouse=True)
def _patch_env(monkeypatch):
    """Autouse fixture: stub out every NPU / distributed dependency.

    Rotary fusion becomes a pass-through, graph mode is disabled, tensor
    parallelism collapses to a single device, distributed barriers are no-ops,
    and both fused-attention entry points are replaced by zero-producing
    dummies. All patches are reverted by monkeypatch after each test.
    """
    if hasattr(attn_mod, "apply_rotary_pos_emb_fusion"):
        monkeypatch.setattr(
            attn_mod, "apply_rotary_pos_emb_fusion",
            fake_apply_rotary_pos_emb_fusion
        )
    if hasattr(attn_mod, "get_use_npu_graph"):
        monkeypatch.setattr(attn_mod, "get_use_npu_graph", lambda: False)
    if hasattr(attn_mod, "get_tensor_parallel_size"):
        monkeypatch.setattr(attn_mod, "get_tensor_parallel_size", lambda: 1)
    if hasattr(attn_mod, "get_tensor_parallel_group"):
        monkeypatch.setattr(attn_mod, "get_tensor_parallel_group", lambda: None)
    if hasattr(attn_mod, "get_current_device"):
        monkeypatch.setattr(attn_mod, "get_current_device", lambda: "cpu")
    # torch.distributed.barrier -> no-op
    if hasattr(torch, "distributed") and hasattr(torch.distributed, "barrier"):
        monkeypatch.setattr(
            torch.distributed, "barrier",
            lambda *args, **kwargs: None,
            raising=False,
        )
    # Replace the fused attention kernel and its workspace helper.
    dummy_op = DummyNpuFusedAttention()
    monkeypatch.setattr(
        torch_npu, "npu_fused_infer_attention_score",
        dummy_op, raising=False
    )
    def fake_get_workspace(q, k, v, **kwargs):
        # A single-element tensor suffices — the dummy kernel never uses it.
        return torch.empty(1, dtype=q.dtype, device=q.device)
    monkeypatch.setattr(
        torch_npu, "_npu_fused_infer_attention_score_get_max_workspace",
        fake_get_workspace, raising=False
    )
    monkeypatch.setattr(torch.ops, "npu", DummyOpsNpu(), raising=False)
    yield
# ==========================
# Test cases
# ==========================
def test_print_callback_smoke():
    """Smoke test: print_callback accepts a full argument tuple without raising."""
    attn = build_attention_module()
    bsz, q_len, hidden_dim = 1, 3, 4
    hidden_states = torch.randn(bsz, q_len, hidden_dim)
    position_ids = torch.arange(q_len).unsqueeze(0)
    cache_position = torch.arange(q_len).unsqueeze(0)
    page_idx = torch.zeros(bsz, dtype=torch.int32)
    page_offset = torch.zeros(bsz, dtype=torch.int32)
    block_table = torch.zeros(bsz, 1, dtype=torch.int32)
    # Only checks that the callback runs; no return value is asserted.
    attn.print_callback(
        (hidden_states, position_ids, cache_position,
         page_idx, page_offset, block_table)
    )
def _common_inputs_prefill():
    """Assemble the common input bundle shared by the prefill forward tests."""
    bsz, q_len, hidden_dim = 1, 3, 4
    int32 = torch.int32
    hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16)
    attention_mask = torch.zeros(bsz, 1, q_len, q_len, dtype=torch.float32)
    position_ids = torch.arange(q_len).unsqueeze(0)
    cache_position = torch.arange(q_len).unsqueeze(0)
    page_idx = torch.zeros(bsz, dtype=int32)
    page_offset = torch.zeros(bsz, dtype=int32)
    block_table = torch.zeros(bsz, 1, dtype=int32)
    past_key_value = DummyStaticCache(page_size=16)
    q_len_raw = torch.tensor([q_len], dtype=int32)
    kv_len_raw = torch.tensor([q_len], dtype=int32)
    return (
        hidden_states, attention_mask, position_ids, cache_position,
        page_idx, page_offset, block_table,
        past_key_value, q_len_raw, kv_len_raw
    )
def test_forward_prefill_with_mask():
    """
    is_prefill=True, attention_mask is not None, past_key_value is not None.
    """
    attn = build_attention_module(q_lora_rank=None)
    (hidden_states, attention_mask, position_ids, cache_position,
     page_idx, page_offset, block_table,
     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()
    outputs = attn.forward(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=False,
        use_cache=True,
        cache_position=cache_position,
        is_prefill=True,
        page_idx=page_idx,
        page_offset=page_offset,
        block_table=block_table,
        q_len_raw=q_len_raw,
        kv_len_raw=kv_len_raw,
        stream=None,
    )
    attn_output, attn_weights, new_cache = outputs
    # Output is flattened back to (bsz, q_len, num_heads * v_head_dim).
    assert attn_output.shape == (
        1,  # bsz
        3,  # q_len
        attn.num_heads * attn.v_head_dim,
    )
    assert attn_weights is None
    assert new_cache is past_key_value
def test_forward_prefill_without_mask_and_q_lora():
    """
    is_prefill=True, attention_mask=None, and the q_lora_rank-is-not-None branch.
    """
    attn = build_attention_module(q_lora_rank=1)
    (hidden_states, attention_mask, position_ids, cache_position,
     page_idx, page_offset, block_table,
     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()
    outputs = attn.forward(
        hidden_states=hidden_states,
        attention_mask=None,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=False,
        use_cache=True,
        cache_position=cache_position,
        is_prefill=True,
        page_idx=None,
        page_offset=None,
        block_table=None,
        q_len_raw=q_len_raw,
        kv_len_raw=kv_len_raw,
        stream=None,
    )
    attn_output, attn_weights, new_cache = outputs
    assert attn_output.shape == (
        1,
        3,
        attn.num_heads * attn.v_head_dim,
    )
    assert attn_weights is None
    assert new_cache is past_key_value
def test_forward_decode_paged_path():
    """
    is_prefill=False with get_use_npu_graph() == False:
    exercises forward_paged and the torch.ops.npu.npu_fused_infer_attention_score branch.
    """
    attn = build_attention_module(q_lora_rank=None)
    # Decode step: a single new token (q_len == 1).
    bsz, q_len, hidden_dim = 1, 1, 4
    hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16)
    position_ids = torch.arange(q_len).unsqueeze(0)
    cache_position = torch.arange(q_len).unsqueeze(0)
    past_key_value = DummyStaticCache(page_size=16)
    q_len_raw = torch.tensor([q_len], dtype=torch.int32)
    kv_len_raw = torch.tensor([q_len], dtype=torch.int32)
    block_table = torch.zeros(bsz, 1, dtype=torch.int32)
    outputs = attn.forward(
        hidden_states=hidden_states,
        attention_mask=None,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=False,
        use_cache=True,
        cache_position=cache_position,
        is_prefill=False,
        page_idx=None,
        page_offset=None,
        block_table=block_table,
        q_len_raw=q_len_raw,
        kv_len_raw=kv_len_raw,
        stream=None,
    )
    attn_output, attn_weights, new_cache = outputs
    assert attn_output.shape == (
        bsz,
        q_len,
        attn.num_heads * attn.v_head_dim,
    )
    assert attn_weights is None
    assert new_cache is past_key_value
def test_forward_prefill_layer_idx_none_raises():
    """
    Covers the error branch: past_key_value is not None while layer_idx is None.
    """
    attn = build_attention_module(q_lora_rank=None)
    attn.layer_idx = None  # deliberately break layer_idx
    (hidden_states, attention_mask, position_ids, cache_position,
     page_idx, page_offset, block_table,
     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()
    with pytest.raises(ValueError):
        attn.forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=False,
            use_cache=True,
            cache_position=cache_position,
            is_prefill=True,
            page_idx=page_idx,
            page_offset=page_offset,
            block_table=block_table,
            q_len_raw=q_len_raw,
            kv_len_raw=kv_len_raw,
            stream=None,
        )
def test_forward_prefill_attn_output_shape_mismatch_raises():
    """
    Covers the ValueError branch taken when attn_output has an unexpected shape.
    """
    attn = build_attention_module(q_lora_rank=None)
    def bad_fused(q, k, v, **kwargs):
        bsz, max_q_len, num_heads, dim = q.shape
        # Use num_heads + 1 on purpose so the size check inside forward fails.
        out = torch.zeros(
            bsz, max_q_len, num_heads + 1, attn.v_head_dim,
            dtype=q.dtype, device=q.device
        )
        lse = torch.zeros(1, dtype=q.dtype, device=q.device)
        return out, lse
    monkeypatch.setattr(
        torch_npu, "npu_fused_infer_attention_score",
        bad_fused, raising=False
    )
    (hidden_states, attention_mask, position_ids, cache_position,
     page_idx, page_offset, block_table,
     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()
    with pytest.raises(ValueError):
        attn.forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=False,
            use_cache=True,
            cache_position=cache_position,
            is_prefill=True,
            page_idx=page_idx,
            page_offset=page_offset,
            block_table=block_table,
            q_len_raw=q_len_raw,
            kv_len_raw=kv_len_raw,
            stream=None,
        )
def test_forward_paged_use_npu_graph(monkeypatch):
    """
    Covers the graph path taken when get_use_npu_graph() == True.
    """
    # Make ascend_attention report that NPU graph mode is enabled.
    monkeypatch.setattr(attn_mod, "get_use_npu_graph", lambda: True)
    # Fake the model_runner module so the lazy
    # `import ktransformers.server.balance_serve.inference.model_runner` succeeds.
    dummy_runner = type(
        "DummyRunner", (), {"__init__": lambda self: setattr(self, "workspace", [None] * 4)}
    )
    dummy_mr = types.SimpleNamespace(
        ModelRunner=dummy_runner,
        get_or_create_model_runner=lambda device=None: dummy_runner(),
    )
    # Register through monkeypatch.setitem so the fake module is removed after
    # this test instead of leaking into sys.modules for the whole session.
    monkeypatch.setitem(
        sys.modules,
        "ktransformers.server.balance_serve.inference.model_runner",
        dummy_mr,
    )
    attn = build_attention_module(q_lora_rank=None)
    bsz, q_len, hidden_dim = 1, 1, 4
    hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16)
    position_ids = torch.arange(q_len).unsqueeze(0)
    cache_position = torch.arange(q_len).unsqueeze(0)
    past_key_value = DummyStaticCache(page_size=16)
    q_len_raw = torch.tensor([q_len], dtype=torch.int32)
    kv_len_raw = torch.tensor([q_len], dtype=torch.int32)
    block_table = torch.zeros(bsz, 1, dtype=torch.int32)
    outputs = attn.forward(
        hidden_states=hidden_states,
        attention_mask=None,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=False,
        use_cache=True,
        cache_position=cache_position,
        is_prefill=False,
        page_idx=None,
        page_offset=None,
        block_table=block_table,
        q_len_raw=q_len_raw,
        kv_len_raw=kv_len_raw,
        stream=None,
    )
    attn_output, attn_weights, new_cache = outputs
    assert attn_output.shape == (
        bsz,
        q_len,
        attn.num_heads * attn.v_head_dim,
    )
    assert attn_weights is None
    assert new_cache is past_key_value

View File

@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import pytest

# Skip the whole module when torch_npu is unavailable. This must run *before*
# importing the ascend operators, which presumably require torch_npu at import
# time — otherwise collection fails with ImportError instead of skipping.
torch_npu = pytest.importorskip("torch_npu")

# Adjust these import paths to match your actual code layout:
from ktransformers.operators.ascend.ascend_layernorm import KDeepseekV3RMSNormW8A8
import ktransformers.util.utils as utils_mod
# ==========================
# Dummy dependencies
# ==========================
class DummyOrigModule(nn.Module):
    """Minimal original-module stand-in carrying the RMSNorm hyperparameters."""

    def __init__(self, hidden_size=4, variance_epsilon=1e-5):
        super().__init__()
        self.variance_epsilon = variance_epsilon
        self.hidden_size = hidden_size
class DummySafeTensorLoader:
    """In-memory tensor store that records every load request by name."""

    def __init__(self):
        self.tensors = {}
        self.load_calls = []

    def load_tensor(self, name: str):
        # Record the request order so tests can assert on the access pattern.
        self.load_calls += [name]
        return self.tensors[name]
class DummyGGUFLoader:
    """GGUF-loader stub that only exposes the safetensor loader used by load()."""

    def __init__(self, safetensor_loader: DummySafeTensorLoader):
        self.safetensor_loader = safetensor_loader
class DummyConfig:
    """Empty config placeholder; the layernorm wrapper never reads it."""
class FakeRMSNorm:
    """Stand-in for torch_npu.npu_rms_norm that simply scales by the weight.

    Every call's arguments are stashed on ``last_args`` for later inspection.
    """

    def __init__(self):
        self.last_args = None

    def __call__(self, hidden_states, weight, eps):
        self.last_args = (hidden_states, weight, eps)
        scaled = hidden_states * weight
        # Real npu_rms_norm returns a tuple whose first element is the output.
        return (scaled,)
def build_rms_module(hidden_size=4, eps=1e-5, safetensor_loader=None):
    """Create a KDeepseekV3RMSNormW8A8 wired to dummy loaders.

    Args:
        hidden_size: width of the normalized dimension.
        eps: variance epsilon copied onto the original module.
        safetensor_loader: optional pre-populated DummySafeTensorLoader;
            a fresh empty one is created when omitted.

    Returns:
        (module, safetensor_loader, orig_module) for use in assertions.
    """
    orig = DummyOrigModule(hidden_size=hidden_size, variance_epsilon=eps)
    if safetensor_loader is None:
        safetensor_loader = DummySafeTensorLoader()
    gguf_loader = DummyGGUFLoader(safetensor_loader)
    config = DummyConfig()
    module = KDeepseekV3RMSNormW8A8(
        key="rms",
        gguf_loader=gguf_loader,
        config=config,
        orig_module=orig,
        prefill_device="npu",
        generate_device="npu",
    )
    return module, safetensor_loader, orig
@pytest.fixture(autouse=True)
def patch_utils_and_npu(monkeypatch):
    """Route device lookup to CPU and swap npu_rms_norm for a recording fake.

    The fake is published as a module attribute (``_fake_rms``) so tests can
    retrieve it via get_fake_rms(); using monkeypatch (instead of a bare
    assignment) ensures the attribute is removed again after each test rather
    than leaking on the test module.
    """
    monkeypatch.setattr(utils_mod, "get_current_device", lambda: "cpu", raising=False)
    fake = FakeRMSNorm()
    monkeypatch.setattr(torch_npu, "npu_rms_norm", fake, raising=False)
    import sys
    # raising=False: the attribute does not exist before the first test runs.
    monkeypatch.setattr(sys.modules[__name__], "_fake_rms", fake, raising=False)
    yield
def get_fake_rms():
    """Return the FakeRMSNorm instance installed by the autouse fixture."""
    # The fixture stores the fake as an attribute of this very module, so the
    # module's global namespace holds it.
    return globals()["_fake_rms"]
def test_forward_preserves_shape_and_dtype():
    """forward() keeps the input shape/dtype and hands (x, weight, eps) to npu_rms_norm."""
    hidden_size = 4
    module, _, orig = build_rms_module(hidden_size=hidden_size, eps=1e-6)
    x = torch.randn(2, 3, hidden_size, dtype=torch.float16)
    out = module(x)
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    # The fake recorded exactly what the module passed to the NPU kernel.
    fake_rms = get_fake_rms()
    hs_arg, w_arg, eps_arg = fake_rms.last_args
    assert hs_arg is x
    assert w_arg is module.weight
    assert eps_arg == orig.variance_epsilon
def test_forward_with_bfloat16_dtype():
    """bfloat16 inputs flow through forward() with shape and dtype preserved."""
    hidden_size = 4
    module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6)
    x = torch.randn(1, 2, hidden_size, dtype=torch.bfloat16)
    out = module(x)
    assert out.shape == x.shape
    assert out.dtype == torch.bfloat16
def test_forward_uses_bias():
    """forward() adds the bias term after the (faked) RMS-norm scaling."""
    hidden_size = 4
    module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6)
    module.weight.data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
    module.bias.data = torch.tensor([-1.0, 0.5, 0.0, 2.0], dtype=torch.float32)
    x = torch.arange(2 * 3 * hidden_size, dtype=torch.float16).view(2, 3, hidden_size)
    out = module(x)
    # FakeRMSNorm returns x * weight, so the expected output is x*w + b.
    expected_rms = x.to(torch.float32) * module.weight
    expected = expected_rms + module.bias
    assert torch.allclose(out, expected.to(out.dtype))
def test_load_from_safetensor_loader():
    """load() pulls '<key>.weight' then '<key>.bias' from the safetensor loader."""
    hidden_size = 4
    module, safe_loader, _ = build_rms_module(hidden_size=hidden_size, eps=1e-5)
    w_loaded = torch.arange(hidden_size, dtype=torch.float32)
    b_loaded = torch.full((hidden_size,), 3.0, dtype=torch.float32)
    safe_loader.tensors["rms.weight"] = w_loaded
    safe_loader.tensors["rms.bias"] = b_loaded
    module.load()
    assert torch.allclose(module.weight, w_loaded)
    assert torch.allclose(module.bias, b_loaded)
    # Exactly two lookups, in weight-then-bias order.
    assert safe_loader.load_calls == ["rms.weight", "rms.bias"]
def test_unload_sets_weight_and_bias_to_none_idempotent():
    """unload() clears weight and bias and is safe to call repeatedly."""
    module, _, _ = build_rms_module(hidden_size=4, eps=1e-5)
    assert module.weight is not None
    assert module.bias is not None
    module.unload()
    assert module.weight is None
    assert module.bias is None
    # Second unload must not raise and leaves both attributes cleared.
    module.unload()
    assert module.weight is None
    assert module.bias is None

View File

@@ -0,0 +1,123 @@
import os
import ast
import argparse
from coverage import Coverage
def _parse_args():
    """Build and parse the CLI arguments for the class-coverage report."""
    parser = argparse.ArgumentParser(
        description="统计某个类在 .coverage 数据中的行覆盖率"
    )
    parser.add_argument(
        "--data-file",
        default=".coverage",
        help="coverage 数据文件路径(默认 ./.coverage",
    )
    parser.add_argument(
        "--file",
        dest="file_pattern",
        default="ktransformers/operators/ascend/ascend_attention.py",
        help=(
            "要统计的源码文件路径(可用结尾匹配,默认 "
            "ktransformers/operators/ascend/ascend_attention.py"
        ),
    )
    parser.add_argument(
        "--class",
        dest="class_name",
        default="KDeepseekV2AttentionW8A8A2Serve",
        help="要统计的类名(默认 KDeepseekV2AttentionW8A8A2Serve",
    )
    return parser.parse_args()


def _find_target_file(data, file_pattern):
    """Return the first measured file whose path matches *file_pattern*, else None."""
    file_pattern_norm = os.path.normpath(file_pattern)
    for f in data.measured_files():
        f_norm = os.path.normpath(f)
        # Suffix match first (most precise), then a substring fallback.
        if f_norm.endswith(file_pattern_norm) or file_pattern_norm in f_norm:
            return f
    return None


def _class_line_range(tree, class_name):
    """Return (start, end) line numbers of top-level *class_name*, or (None, None).

    The end line is the maximum lineno/end_lineno over all nodes inside the
    class, which also covers Pythons without ``end_lineno`` on every node.
    """
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            max_lineno = node.lineno
            for sub in ast.walk(node):
                ln = getattr(sub, "end_lineno", getattr(sub, "lineno", None))
                if ln is not None and ln > max_lineno:
                    max_lineno = ln
            return node.lineno, max_lineno
    return None, None


def main():
    """Report line coverage for one class recorded in a .coverage data file."""
    args = _parse_args()
    if not os.path.exists(args.data_file):
        print(f"找不到 coverage 数据文件: {args.data_file}")
        raise SystemExit(1)
    cov = Coverage(data_file=args.data_file)
    cov.load()
    data = cov.get_data()
    target_file = _find_target_file(data, args.file_pattern)
    if not target_file:
        print(
            f"没有在 coverage 数据里找到匹配文件: {args.file_pattern}\n"
            f"实际记录的文件有:"
        )
        for f in data.measured_files():
            print(" ", f)
        raise SystemExit(1)
    print("使用的源码文件:", target_file)
    executed_lines = set(data.lines(target_file) or [])
    try:
        with open(target_file, "r", encoding="utf-8") as f:
            source_text = f.read()
    except OSError as e:
        print(f"无法打开源码文件 {target_file}: {e}")
        raise SystemExit(1)
    source_lines = source_text.splitlines()
    tree = ast.parse(source_text)
    class_start, class_end = _class_line_range(tree, args.class_name)
    if class_start is None:
        print(f"在源码 {target_file} 中没有找到类 {args.class_name}")
        raise SystemExit(1)
    print(
        f"{args.class_name} 行范围: {class_start} ~ {class_end}"
    )
    total = 0
    covered = 0
    missed_lines = []
    for lineno in range(class_start, class_end + 1):
        line = source_lines[lineno - 1].strip()
        # Skip blank lines and pure comment lines.
        if not line or line.startswith("#"):
            continue
        total += 1
        if lineno in executed_lines:
            covered += 1
        else:
            missed_lines.append(lineno)
    percent = (covered / total * 100) if total > 0 else 0.0
    print(
        f"{args.class_name} 覆盖: {covered}/{total} 行, 覆盖率 = {percent:.1f}%"
    )
    if missed_lines:
        print("未覆盖行号:", missed_lines)
    else:
        print("该类所有有效代码行均被覆盖")


if __name__ == "__main__":
    main()