[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-10 20:12:43 +00:00
committed by assistant-librarian[bot]
parent b8def2c724
commit 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -20,3 +20,21 @@ TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesCompAsync);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME
template <typename T>
class TestCkTileGemmPipelineCompAsync16x16x128
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompAsync16x16x128<T>>
{
public:
static constexpr bool check_data_type() { return true; }
};
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync16x16x128, KernelTypesCompAsync16x16x128);
TYPED_TEST(TestCkTileGemmPipelineCompAsync16x16x128, QuickTest)
{
constexpr int M = 1024;
constexpr int N = 1024;
constexpr int K = 1024;
this->template RunSingle<false, false, false, false>(M, N, K, 0, 0, 0, 1);
}

View File

@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I64 = ck_tile::number<64>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
// clang-format off
@@ -224,6 +225,23 @@ using CompAsyncConfig = std::tuple<ALayout,
Intrawave,
CompAsync>;
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncConfig16x16x128 = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I64, // MBlockTileSize
I64, // NBlockTileSize
I128, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
I128, // KWarpTileSize
Intrawave,
CompAsync>;
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
CompAsyncConfig<Row, Col, Row, F16>,
CompAsyncConfig<Col, Row, Row, F16>,
@@ -232,6 +250,10 @@ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16
CompAsyncConfig<Row, Col, Row, F8>,
CompAsyncConfig<Col, Row, Row, F8>,
CompAsyncConfig<Col, Col, Row, F8>>;
using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<Row, Col, Row, F4>,
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
// clang-format off
using KernelTypesCompV6 = ::testing::Types<

View File

@@ -7,6 +7,7 @@ using INT32 = ck_tile::int32_t;
using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using F4 = ck_tile::pk_fp4_t;
using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;