[CLI] add cutedsl fp16 gemm tutorial from 2 to 6 (#3106)

* [CLI] add fp16 gemm tutorial from 2 to 6

* [CLI] refine comments
This commit is contained in:
Linfeng Zheng
2026-03-17 10:11:55 +08:00
committed by GitHub
parent 087c84df83
commit 772fbb264e
9 changed files with 5496 additions and 93 deletions

View File

@@ -28,6 +28,12 @@
from blackwell.tutorial_gemm import fp16_gemm_0
from blackwell.tutorial_gemm import fp16_gemm_1
from blackwell.tutorial_gemm import fp16_gemm_2
from blackwell.tutorial_gemm import fp16_gemm_3
from blackwell.tutorial_gemm import fp16_gemm_3_1
from blackwell.tutorial_gemm import fp16_gemm_4
from blackwell.tutorial_gemm import fp16_gemm_5
from blackwell.tutorial_gemm import fp16_gemm_6
import pytest
from typing import Tuple
@@ -54,4 +60,77 @@ def test_fp16_gemm_1(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_1.run_dense_gemm(mnk, tolerance)
fp16_gemm_1.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_2(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_2.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_3(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_3.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_3_1(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_3_1.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_4(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_4.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_5(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_5.run_dense_gemm(mnk, tolerance)
@pytest.mark.parametrize(
"mnk",
[(512, 512, 256)],
)
@pytest.mark.parametrize("tolerance", [1e-01])
def test_fp16_gemm_6(
mnk: Tuple[int, int, int],
tolerance: float,
):
fp16_gemm_6.run_dense_gemm(mnk, tolerance)