[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-10 20:12:43 +00:00
committed by assistant-librarian[bot]
parent b8def2c724
commit 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -57,6 +57,7 @@ add_subdirectory(add_rmsnorm2d_rdquant)
# add_subdirectory(layernorm2d)
# add_subdirectory(rmsnorm2d)
add_subdirectory(gemm_block_scale)
add_subdirectory(gemm_mx)
add_subdirectory(utility)
add_subdirectory(warp_gemm)
add_subdirectory(reduce)

View File

@@ -20,3 +20,21 @@ TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesCompAsync);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME
template <typename T>
class TestCkTileGemmPipelineCompAsync16x16x128
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineCompAsync16x16x128<T>>
{
public:
static constexpr bool check_data_type() { return true; }
};
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync16x16x128, KernelTypesCompAsync16x16x128);
TYPED_TEST(TestCkTileGemmPipelineCompAsync16x16x128, QuickTest)
{
constexpr int M = 1024;
constexpr int N = 1024;
constexpr int K = 1024;
this->template RunSingle<false, false, false, false>(M, N, K, 0, 0, 0, 1);
}

View File

@@ -29,6 +29,7 @@ using NonPersistent = std::false_type;
using I16 = ck_tile::number<16>;
using I32 = ck_tile::number<32>;
using I64 = ck_tile::number<64>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
// clang-format off
@@ -224,6 +225,23 @@ using CompAsyncConfig = std::tuple<ALayout,
Intrawave,
CompAsync>;
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncConfig16x16x128 = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I64, // MBlockTileSize
I64, // NBlockTileSize
I128, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
I128, // KWarpTileSize
Intrawave,
CompAsync>;
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
CompAsyncConfig<Row, Col, Row, F16>,
CompAsyncConfig<Col, Row, Row, F16>,
@@ -232,6 +250,10 @@ using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16
CompAsyncConfig<Row, Col, Row, F8>,
CompAsyncConfig<Col, Row, Row, F8>,
CompAsyncConfig<Col, Col, Row, F8>>;
using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<Row, Col, Row, F4>,
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
// clang-format off
using KernelTypesCompV6 = ::testing::Types<

View File

@@ -7,6 +7,7 @@ using INT32 = ck_tile::int32_t;
using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using F4 = ck_tile::pk_fp4_t;
using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;

View File

@@ -0,0 +1,17 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(TEST_MX_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template)
if(CK_USE_OCP_FP8)
list(APPEND TEST_MX_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile MX GEMM tests for current target")
endif()

View File

@@ -0,0 +1,95 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp"
template <typename ScaleM, typename ScaleN>
struct MXGemmHostArgs : ck_tile::UniversalGemmHostArgs<1, 1, 0>
{
using Base = ck_tile::UniversalGemmHostArgs<1, 1, 0>;
MXGemmHostArgs(const void* a_ptr,
const void* b_ptr,
void* c_ptr_,
ck_tile::index_t k_batch_,
ck_tile::index_t M_,
ck_tile::index_t N_,
ck_tile::index_t K_,
ck_tile::index_t stride_A_,
ck_tile::index_t stride_B_,
ck_tile::index_t stride_C_,
ScaleM scale_m_,
ScaleN scale_n_)
: Base({a_ptr},
{b_ptr},
{},
c_ptr_,
k_batch_,
M_,
N_,
K_,
{stride_A_},
{stride_B_},
{},
stride_C_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
ScaleM scale_m;
ScaleN scale_n;
};
struct MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 512;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool Preshuffle = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp4_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};
struct MXfp8_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};

View File

@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp4Types = ::testing::Types<
std::tuple<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, MXfp4_GemmConfig16, Row, Col, Row>>;
template <typename TypeParam>
class TestMxGemmFp4 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp4, MxFp4Types);
TYPED_TEST(TestMxGemmFp4, BasicSizes)
{
this->Run(64, 64, 256);
this->Run(128, 128, 256);
this->Run(64, 128, 512);
}

View File

@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp8Types =
::testing::Types<std::tuple<ck_tile::fp8_t, ck_tile::fp8_t, MXfp8_GemmConfig16, Row, Col, Row>>;
template <typename TypeParam>
class TestMxGemmFp8 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp8, MxFp8Types);
TYPED_TEST(TestMxGemmFp8, BasicSizes)
{
this->Run(64, 64, 256);
this->Run(128, 128, 256);
this->Run(64, 128, 512);
}

View File

@@ -0,0 +1,97 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp"
#include "test_mx_gemm_config.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
bool persistent,
bool Splitk>
float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using MXPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
MXGemmTraits,
GemmConfig::Scheduler>;
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC>>;
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(std::array<const void*, 1>{args.as_ptr},
std::array<const void*, 1>{args.bs_ptr},
std::array<const void*, 0>{},
args.e_ptr,
args.k_batch,
args.M,
args.N,
args.K,
std::array<ck_tile::index_t, 1>{args.stride_As},
std::array<ck_tile::index_t, 1>{args.stride_Bs},
std::array<ck_tile::index_t, 0>{},
args.stride_E,
args.scale_m,
args.scale_n);
const auto kernel = ck_tile::make_kernel<Kernel::kBlockPerCu>(
Kernel{}, Kernel::GridSize(kargs), Kernel::BlockSize(), 0, kargs);
return ck_tile::launch_kernel(s, kernel);
}

View File

@@ -0,0 +1,137 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_instance.hpp"
template <typename Layout>
static constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(K);
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value, K);
return ck_tile::make_tuple(rtol, atol);
}
template <typename ADataType,
typename BDataType,
typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
class TestMxGemmUtil : public ::testing::Test
{
protected:
using AccDataType = float;
using CDataType = ck_tile::fp16_t;
using ScaleType = ck_tile::e8m0_t;
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
{
const ck_tile::index_t scale_k_size = K / 32;
const ck_tile::index_t stride_A =
ck_tile::get_default_stride(M, K, 0, is_row_major(ALayout{}));
const ck_tile::index_t stride_B =
ck_tile::get_default_stride(K, N, 0, is_row_major(BLayout{}));
const ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, 0, is_row_major(CLayout{}));
const ck_tile::index_t stride_scale_a =
ck_tile::get_default_stride(M, scale_k_size, 0, is_row_major(ALayout{}));
const ck_tile::index_t stride_scale_b =
ck_tile::get_default_stride(scale_k_size, N, 0, is_row_major(BLayout{}));
ck_tile::HostTensor<ADataType> a_host(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<CDataType> c_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ScaleType> scale_a_host(ck_tile::host_tensor_descriptor(
M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
ck_tile::HostTensor<ScaleType> scale_b_host(ck_tile::host_tensor_descriptor(
scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f, seed++}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f, seed++}(b_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
b_dev_buf.ToDevice(b_host.data());
c_dev_buf.SetZero();
scale_a_dev_buf.ToDevice(scale_a_host.data());
scale_b_dev_buf.ToDevice(scale_b_host.data());
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
MXGemmHostArgs<ScaleM, ScaleN> args(a_dev_buf.GetDeviceBuffer(),
b_dev_buf.GetDeviceBuffer(),
c_dev_buf.GetDeviceBuffer(),
1,
M,
N,
K,
stride_A,
stride_B,
stride_C,
scale_m,
scale_n);
mx_gemm_calc<GemmConfig,
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleM,
ScaleN,
true,
false>(args, ck_tile::stream_config{nullptr, true, 1, 0, 1, true, true, 50});
c_dev_buf.FromDevice(c_host.data());
ck_tile::HostTensor<CDataType> c_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref.SetZero();
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
a_host, b_host, c_ref, scale_a_host, scale_b_host);
const float max_accumulated_value = ck_tile::type_convert<float>(c_ref.max());
const auto rtol_atol = calculate_rtol_atol_mx<ADataType, BDataType, AccDataType, CDataType>(
K, max_accumulated_value);
const double rtol = rtol_atol.at(ck_tile::number<0>{});
const double atol = rtol_atol.at(ck_tile::number<1>{});
bool pass = ck_tile::check_err(c_host, c_ref, "MX GEMM: Incorrect results!", rtol, atol);
EXPECT_TRUE(pass);
}
};