diff --git a/.claude/skills/add-jit-kernel/SKILL.md b/.claude/skills/add-jit-kernel/SKILL.md index 3175e761c..d7f944bf5 100644 --- a/.claude/skills/add-jit-kernel/SKILL.md +++ b/.claude/skills/add-jit-kernel/SKILL.md @@ -446,13 +446,7 @@ def test_scale_unsupported_dtype(): if __name__ == "__main__": - pytest.main([__file__, "-q"]) -``` - -Run: - -```bash -pytest python/sglang/jit_kernel/tests/test_scale.py -q + pytest.main([__file__, "-v", "-s"]) ``` --- diff --git a/python/sglang/jit_kernel/tests/test_add_constant.py b/python/sglang/jit_kernel/tests/test_add_constant.py index d588fc518..36ea024ba 100644 --- a/python/sglang/jit_kernel/tests/test_add_constant.py +++ b/python/sglang/jit_kernel/tests/test_add_constant.py @@ -1,14 +1,16 @@ +import pytest import torch from sglang.jit_kernel.add_constant import add_constant -def main(): - c = 1024 - src = torch.arange(0, 1024 + 1, dtype=torch.int32).cuda() - dst = add_constant(src, c) - assert torch.all(dst == src + c) +@pytest.mark.parametrize("size", [1, 2, 127, 128, 1024, 1025]) +@pytest.mark.parametrize("constant", [0, 1, 7, 1024, -3]) +def test_add_constant(size: int, constant: int) -> None: + src = torch.arange(0, size, dtype=torch.int32, device="cuda") + dst = add_constant(src, constant) + assert torch.all(dst == src + constant) if __name__ == "__main__": - main() + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_awq_dequantize.py b/python/sglang/jit_kernel/tests/test_awq_dequantize.py index e29475843..d2970e99b 100644 --- a/python/sglang/jit_kernel/tests/test_awq_dequantize.py +++ b/python/sglang/jit_kernel/tests/test_awq_dequantize.py @@ -161,4 +161,4 @@ def test_awq_dequantize_jit_vs_aot( if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py b/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py index 217dfc0a6..e4741b373 100644 --- a/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py +++ b/python/sglang/jit_kernel/tests/test_awq_marlin_moe_repack.py @@ -112,6 +112,4 @@ def test_awq_marlin_moe_repack_shape( if __name__ == "__main__": - import subprocess - - subprocess.call(["pytest", "--tb=short", str(__file__)]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py b/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py index 819fcf276..ba959ccac 100644 --- a/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py +++ b/python/sglang/jit_kernel/tests/test_awq_marlin_repack.py @@ -98,6 +98,4 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size): if __name__ == "__main__": - import subprocess - - subprocess.call(["pytest", "--tb=short", str(__file__)]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_concat_mla.py b/python/sglang/jit_kernel/tests/test_concat_mla.py index 6c5d3631d..cf1013c5e 100644 --- a/python/sglang/jit_kernel/tests/test_concat_mla.py +++ b/python/sglang/jit_kernel/tests/test_concat_mla.py @@ -166,4 +166,4 @@ def test_concat_mla_absorb_q_jit_vs_aot(dim_0: int, dim_1: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py b/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py index 3831ac2e0..4139f24ae 100644 --- a/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py +++ b/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py @@ -296,4 +296,4 @@ def test_cutedsl_gdn_performance(B: int): if __name__ == "__main__": - pytest.main([__file__, "-v"]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_flash_attention_4.py b/python/sglang/jit_kernel/tests/test_flash_attention_4.py index fe19c9175..1540d4601 100644 --- a/python/sglang/jit_kernel/tests/test_flash_attention_4.py +++ b/python/sglang/jit_kernel/tests/test_flash_attention_4.py @@ -1501,4 +1501,4 @@ def _generate_block_kvcache( if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py b/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py index a763408d0..52c2dc612 100644 --- a/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py +++ b/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm.py @@ -52,4 +52,4 @@ def test_fused_add_rmsnorm(batch_size: int, hidden_size: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py b/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py index 443edf3cc..592103b12 100644 --- a/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py +++ b/python/sglang/jit_kernel/tests/test_fused_norm_scale_shift.py @@ -233,4 +233,4 @@ class TestFusedScaleResidualNormScaleShift: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py b/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py index 6a1b401cc..00edb5819 100644 --- a/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py +++ b/python/sglang/jit_kernel/tests/test_fused_store_index_cache.py @@ -450,4 +450,4 @@ def test_reference_writes_nonzero(): if __name__ == "__main__": - pytest.main([__file__, "-v", "--tb=short"]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_gptq_marlin.py b/python/sglang/jit_kernel/tests/test_gptq_marlin.py index b023b5956..c7cdb1e6c 100644 --- a/python/sglang/jit_kernel/tests/test_gptq_marlin.py +++ b/python/sglang/jit_kernel/tests/test_gptq_marlin.py @@ -96,6 +96,4 @@ def test_gptq_marlin_gemm( if __name__ == "__main__": - import subprocess - - subprocess.call(["pytest", "--tb=short", str(__file__)]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py b/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py index ef02be3e8..0c571dbff 100644 --- a/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py +++ b/python/sglang/jit_kernel/tests/test_gptq_marlin_repack.py @@ -87,6 +87,4 @@ def test_gptq_marlin_repack( if __name__ == "__main__": - import subprocess - - subprocess.call(["pytest", "--tb=short", str(__file__)]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py b/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py index f894dc118..e40f82461 100644 --- a/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py +++ b/python/sglang/jit_kernel/tests/test_moe_wna16_marlin.py @@ -324,6 +324,4 @@ def test_moe_wna16_marlin_gemm( if __name__ == "__main__": - import subprocess - - subprocess.call(["pytest", "--tb=short", "-v", str(__file__)]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py b/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py index a08b698f9..b560127ae 100644 --- a/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py +++ b/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py @@ -83,4 +83,4 @@ def test_jit_per_tensor_quant_supports_3d(shape): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py b/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py index 55f6f35be..eebd49527 100644 --- a/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py +++ b/python/sglang/jit_kernel/tests/test_per_token_group_quant_8bit.py @@ -202,4 +202,4 @@ def test_per_token_group_quant_with_column_major( if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_pos_enc.py b/python/sglang/jit_kernel/tests/test_pos_enc.py index 3656a6c2f..4b809b002 100644 --- a/python/sglang/jit_kernel/tests/test_pos_enc.py +++ b/python/sglang/jit_kernel/tests/test_pos_enc.py @@ -371,10 +371,10 @@ def test_correctness( ], ) def test_performance( - head_size, - rotary_dim, - max_position_embeddings, - base, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, is_neox_style, dtype, device, @@ -483,3 +483,7 @@ def test_performance( print(f"Speedup (SGL/JIT): {speedup:.2f}x") assert jit_time >= 0 and sgl_time >= 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_qknorm.py b/python/sglang/jit_kernel/tests/test_qknorm.py index 4dd1963d8..ee72e9ec6 100644 --- a/python/sglang/jit_kernel/tests/test_qknorm.py +++ b/python/sglang/jit_kernel/tests/test_qknorm.py @@ -90,4 +90,4 @@ def test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py b/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py index d00c713de..f090d82ad 100644 --- a/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py +++ b/python/sglang/jit_kernel/tests/test_qknorm_across_heads.py @@ -72,4 +72,4 @@ def test_qknorm_across_heads(batch_size: int, hidden_dim: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_rmsnorm.py b/python/sglang/jit_kernel/tests/test_rmsnorm.py index 168a95334..501124daf 100644 --- a/python/sglang/jit_kernel/tests/test_rmsnorm.py +++ b/python/sglang/jit_kernel/tests/test_rmsnorm.py @@ -38,4 +38,4 @@ def test_rmsnorm(batch_size: int, hidden_size: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_rope.py b/python/sglang/jit_kernel/tests/test_rope.py index 99c271966..0aba5cf4c 100644 --- a/python/sglang/jit_kernel/tests/test_rope.py +++ b/python/sglang/jit_kernel/tests/test_rope.py @@ -241,4 +241,4 @@ def test_fused_rope_store( if __name__ == "__main__": - pytest.main([__file__, "-v"]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_store_cache.py b/python/sglang/jit_kernel/tests/test_store_cache.py index 770f257f9..ee5ddae14 100644 --- a/python/sglang/jit_kernel/tests/test_store_cache.py +++ b/python/sglang/jit_kernel/tests/test_store_cache.py @@ -32,4 +32,4 @@ def test_store_cache(batch_size: int, element_dim: int) -> None: if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_timestep_embedding.py b/python/sglang/jit_kernel/tests/test_timestep_embedding.py index 2ed242429..068363774 100644 --- a/python/sglang/jit_kernel/tests/test_timestep_embedding.py +++ b/python/sglang/jit_kernel/tests/test_timestep_embedding.py @@ -157,4 +157,4 @@ def test_timestep_embedding_perf(): if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v", "-s"])