mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-20 12:59:01 +00:00
[CLI] add cutedsl fp16 gemm tutorial from 2 to 6 (#3106)
* [CLI] add fp16 gemm tutorial from 2 to 6 * [CLI] refine comments
This commit is contained in:
@@ -28,6 +28,12 @@
|
||||
|
||||
from blackwell.tutorial_gemm import fp16_gemm_0
|
||||
from blackwell.tutorial_gemm import fp16_gemm_1
|
||||
from blackwell.tutorial_gemm import fp16_gemm_2
|
||||
from blackwell.tutorial_gemm import fp16_gemm_3
|
||||
from blackwell.tutorial_gemm import fp16_gemm_3_1
|
||||
from blackwell.tutorial_gemm import fp16_gemm_4
|
||||
from blackwell.tutorial_gemm import fp16_gemm_5
|
||||
from blackwell.tutorial_gemm import fp16_gemm_6
|
||||
|
||||
import pytest
|
||||
from typing import Tuple
|
||||
@@ -54,4 +60,77 @@ def test_fp16_gemm_1(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_1.run_dense_gemm(mnk, tolerance)
|
||||
fp16_gemm_1.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_2(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_2.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_3(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_3.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_3_1(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_3_1.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_4(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_4.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_5(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_5.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnk",
|
||||
[(512, 512, 256)],
|
||||
)
|
||||
@pytest.mark.parametrize("tolerance", [1e-01])
|
||||
def test_fp16_gemm_6(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
fp16_gemm_6.run_dense_gemm(mnk, tolerance)
|
||||
|
||||
Reference in New Issue
Block a user