mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-02 04:37:14 +00:00
[jit_kernel] Tiny unify jit_kernel tests style (#19694)
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.jit_kernel.add_constant import add_constant
|
||||
|
||||
|
||||
def main():
|
||||
c = 1024
|
||||
src = torch.arange(0, 1024 + 1, dtype=torch.int32).cuda()
|
||||
dst = add_constant(src, c)
|
||||
assert torch.all(dst == src + c)
|
||||
@pytest.mark.parametrize("size", [1, 2, 127, 128, 1024, 1025])
|
||||
@pytest.mark.parametrize("constant", [0, 1, 7, 1024, -3])
|
||||
def test_add_constant(size: int, constant: int) -> None:
|
||||
src = torch.arange(0, size, dtype=torch.int32, device="cuda")
|
||||
dst = add_constant(src, constant)
|
||||
assert torch.all(dst == src + constant)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -161,4 +161,4 @@ def test_awq_dequantize_jit_vs_aot(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -112,6 +112,4 @@ def test_awq_marlin_moe_repack_shape(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -98,6 +98,4 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -166,4 +166,4 @@ def test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -296,4 +296,4 @@ def test_cutedsl_gdn_performance(B: int):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -1501,4 +1501,4 @@ def _generate_block_kvcache(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -52,4 +52,4 @@ def test_fused_add_rmsnorm(batch_size: int, hidden_size: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -233,4 +233,4 @@ class TestFusedScaleResidualNormScaleShift:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -450,4 +450,4 @@ def test_reference_writes_nonzero():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -96,6 +96,4 @@ def test_gptq_marlin_gemm(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -87,6 +87,4 @@ def test_gptq_marlin_repack(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", str(__file__)])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -324,6 +324,4 @@ def test_moe_wna16_marlin_gemm(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import subprocess
|
||||
|
||||
subprocess.call(["pytest", "--tb=short", "-v", str(__file__)])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -83,4 +83,4 @@ def test_jit_per_tensor_quant_supports_3d(shape):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -202,4 +202,4 @@ def test_per_token_group_quant_with_column_major(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -371,10 +371,10 @@ def test_correctness(
|
||||
],
|
||||
)
|
||||
def test_performance(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
@@ -483,3 +483,7 @@ def test_performance(
|
||||
print(f"Speedup (SGL/JIT): {speedup:.2f}x")
|
||||
|
||||
assert jit_time >= 0 and sgl_time >= 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -90,4 +90,4 @@ def test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -72,4 +72,4 @@ def test_qknorm_across_heads(batch_size: int, hidden_dim: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -38,4 +38,4 @@ def test_rmsnorm(batch_size: int, hidden_size: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -241,4 +241,4 @@ def test_fused_rope_store(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -32,4 +32,4 @@ def test_store_cache(batch_size: int, element_dim: int) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
@@ -157,4 +157,4 @@ def test_timestep_embedding_perf():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
||||
Reference in New Issue
Block a user