[jit_kernel] Tiny unify jit_kernel tests style (#19694)

This commit is contained in:
Xiaoyu Zhang
2026-03-02 21:33:59 +08:00
committed by GitHub
parent 714c53d609
commit 53de53fb53
23 changed files with 37 additions and 47 deletions

View File

@@ -1,14 +1,16 @@
import pytest
import torch
from sglang.jit_kernel.add_constant import add_constant
def main():
c = 1024
src = torch.arange(0, 1024 + 1, dtype=torch.int32).cuda()
dst = add_constant(src, c)
assert torch.all(dst == src + c)
@pytest.mark.parametrize("size", [1, 2, 127, 128, 1024, 1025])
@pytest.mark.parametrize("constant", [0, 1, 7, 1024, -3])
def test_add_constant(size: int, constant: int) -> None:
src = torch.arange(0, size, dtype=torch.int32, device="cuda")
dst = add_constant(src, constant)
assert torch.all(dst == src + constant)
if __name__ == "__main__":
main()
pytest.main([__file__, "-v", "-s"])

View File

@@ -161,4 +161,4 @@ def test_awq_dequantize_jit_vs_aot(
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -112,6 +112,4 @@ def test_awq_marlin_moe_repack_shape(
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])
pytest.main([__file__, "-v", "-s"])

View File

@@ -98,6 +98,4 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])
pytest.main([__file__, "-v", "-s"])

View File

@@ -166,4 +166,4 @@ def test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -296,4 +296,4 @@ def test_cutedsl_gdn_performance(B: int):
if __name__ == "__main__":
pytest.main([__file__, "-v"])
pytest.main([__file__, "-v", "-s"])

View File

@@ -1501,4 +1501,4 @@ def _generate_block_kvcache(
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -52,4 +52,4 @@ def test_fused_add_rmsnorm(batch_size: int, hidden_size: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -233,4 +233,4 @@ class TestFusedScaleResidualNormScaleShift:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -450,4 +450,4 @@ def test_reference_writes_nonzero():
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
pytest.main([__file__, "-v", "-s"])

View File

@@ -96,6 +96,4 @@ def test_gptq_marlin_gemm(
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])
pytest.main([__file__, "-v", "-s"])

View File

@@ -87,6 +87,4 @@ def test_gptq_marlin_repack(
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", str(__file__)])
pytest.main([__file__, "-v", "-s"])

View File

@@ -324,6 +324,4 @@ def test_moe_wna16_marlin_gemm(
if __name__ == "__main__":
import subprocess
subprocess.call(["pytest", "--tb=short", "-v", str(__file__)])
pytest.main([__file__, "-v", "-s"])

View File

@@ -83,4 +83,4 @@ def test_jit_per_tensor_quant_supports_3d(shape):
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -202,4 +202,4 @@ def test_per_token_group_quant_with_column_major(
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -371,10 +371,10 @@ def test_correctness(
],
)
def test_performance(
head_size,
rotary_dim,
max_position_embeddings,
base,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style,
dtype,
device,
@@ -483,3 +483,7 @@ def test_performance(
print(f"Speedup (SGL/JIT): {speedup:.2f}x")
assert jit_time >= 0 and sgl_time >= 0
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@@ -90,4 +90,4 @@ def test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -72,4 +72,4 @@ def test_qknorm_across_heads(batch_size: int, hidden_dim: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -38,4 +38,4 @@ def test_rmsnorm(batch_size: int, hidden_size: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -241,4 +241,4 @@ def test_fused_rope_store(
if __name__ == "__main__":
pytest.main([__file__, "-v"])
pytest.main([__file__, "-v", "-s"])

View File

@@ -32,4 +32,4 @@ def test_store_cache(batch_size: int, element_dim: int) -> None:
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])

View File

@@ -157,4 +157,4 @@ def test_timestep_embedding_perf():
if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__, "-v", "-s"])