mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -20,3 +20,21 @@ TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesCompAsync);
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompAsync16x16x128
|
||||
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompAsync16x16x128<T>>
|
||||
{
|
||||
public:
|
||||
static constexpr bool check_data_type() { return true; }
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync16x16x128, KernelTypesCompAsync16x16x128);
|
||||
TYPED_TEST(TestCkTileGemmPipelineCompAsync16x16x128, QuickTest)
|
||||
{
|
||||
constexpr int M = 1024;
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 1024;
|
||||
|
||||
this->template RunSingle<false, false, false, false>(M, N, K, 0, 0, 0, 1);
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
|
||||
using I16 = ck_tile::number<16>;
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
@@ -224,6 +225,23 @@ using CompAsyncConfig = std::tuple<ALayout,
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncConfig16x16x128 = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I64, // MBlockTileSize
|
||||
I64, // NBlockTileSize
|
||||
I128, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
I128, // KWarpTileSize
|
||||
Intrawave,
|
||||
CompAsync>;
|
||||
|
||||
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
|
||||
CompAsyncConfig<Row, Col, Row, F16>,
|
||||
CompAsyncConfig<Col, Row, Row, F16>,
|
||||
@@ -232,6 +250,10 @@ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16
|
||||
CompAsyncConfig<Row, Col, Row, F8>,
|
||||
CompAsyncConfig<Col, Row, Row, F8>,
|
||||
CompAsyncConfig<Col, Col, Row, F8>>;
|
||||
|
||||
using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<Row, Col, Row, F4>,
|
||||
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
|
||||
|
||||
// clang-format off
|
||||
|
||||
using KernelTypesCompV6 = ::testing::Types<
|
||||
|
||||
@@ -7,6 +7,7 @@ using INT32 = ck_tile::int32_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F4 = ck_tile::pk_fp4_t;
|
||||
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
|
||||
Reference in New Issue
Block a user