[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile.

Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

- Extend `pk_fp6_t`: default constructor, unsigned-safe bit unpacking, and a host-side `float_to_fp6_e2m3` encoder; add `exp`/`mant` fields to `numeric_traits<pk_fp6x16_t>`.
- Introduce `lds_padded_sizeof<T>()`: `buffer_load_dwordx3` uses a fixed 16-byte per-thread stride, so 12-byte elements are padded to 16 bytes in LDS. Apply it in the LDS buffer view, tile windows (`tile_scatter_gather`, `tile_window_with_static_distribution`), and GEMM pipeline SMEM sizing/asserts.
- Scale vector-load element counts by `numeric_traits<T>::PackedSize` in `UniversalGemmBasePolicy`.
- Add MXFP6 support in test utilities: `FillUniformDistribution<pk_fp6x16_t>`, error thresholds, an `MXfp6_GemmConfig16` config, and a new `test_ck_tile_mx_gemm_fp6` gtest target.

## Test Plan

Added `test_mx_gemm_fp6.cpp` (typed gtest with `pk_fp6x16_t` A/B operands and `MXfp6_GemmConfig16`) covering problem sizes 64x64x256, 128x128x256, and 64x128x512, wired into the gfx95 CMake test targets.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-20 03:07:47 +02:00
committed by GitHub
parent ba10383bb8
commit bf707265a8
13 changed files with 160 additions and 31 deletions

View File

@@ -22,7 +22,10 @@ struct pk_fp6_t
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
element_type data_[vector_size]; // packed data
using type = pk_fp6_t<packed_size>;
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value = 0)
CK_TILE_HOST_DEVICE constexpr pk_fp6_t() : data_{element_type{}} {}
CK_TILE_HOST_DEVICE constexpr explicit pk_fp6_t(int value)
{
for(size_t i = 0; i < vector_size; ++i)
{
@@ -59,13 +62,14 @@ struct pk_fp6_t
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
int32_t bits = pk.data_[arr_idx] >> bit_offset;
uint32_t bits = static_cast<uint32_t>(pk.data_[arr_idx]) >> bit_offset;
if(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
bits |= (static_cast<uint32_t>(pk.data_[arr_idx + 1]) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return bits & 0x3F;
return static_cast<int32_t>(bits & 0x3F);
}
CK_TILE_HOST_DEVICE int32_t unpack(const index_t i) const { return unpack(*this, i); }
@@ -97,6 +101,22 @@ struct pk_fp6_t
}
return sign == 1 ? -1 * result : result;
}
// Encode a float as the nearest FP6 E2M3 code by exhaustively comparing
// against all 64 decoded values; ties keep the lowest code (strict '<').
// Host-only helper intended for test-data generation, not hot paths.
CK_TILE_HOST static int32_t float_to_fp6_e2m3(float val)
{
    int32_t nearest_code = 0;
    float smallest_err   = std::fabs(val - fp6_e2m3_to_float(0));
    for(int32_t code = 1; code < 64; ++code)
    {
        const float err = std::fabs(val - fp6_e2m3_to_float(code));
        if(err < smallest_err)
        {
            smallest_err = err;
            nearest_code = code;
        }
    }
    return nearest_code;
}
};
using pk_fp6x16_t = pk_fp6_t<16>;
@@ -105,5 +125,7 @@ template <>
struct numeric_traits<pk_fp6x16_t>
{
static constexpr int PackedSize = 16;
static constexpr int exp = 2;
static constexpr int mant = 3;
};
} // namespace ck_tile

View File

@@ -19,6 +19,14 @@
namespace ck_tile {
// buffer_load_dwordx3 to LDS uses a fixed 16-byte per-thread stride,
// padding each 12-byte element to 16 bytes in LDS.
// Per-thread LDS stride in bytes for element type T.
// buffer_load_dwordx3 writes to LDS with a fixed 16-byte per-thread stride,
// so 12-byte (dwordx3) elements occupy 16 bytes each; every other size is
// stored unpadded.
template <typename T>
CK_TILE_HOST_DEVICE constexpr index_t lds_padded_sizeof()
{
    constexpr index_t elem_bytes = sizeof(T);
    if(elem_bytes == 12)
    {
        return 16; // pad dwordx3 elements to the hardware's 16-byte stride
    }
    return elem_bytes;
}
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
@@ -840,7 +848,10 @@ struct buffer_view<address_space_enum::lds,
{
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
constexpr index_t padded_stride = lds_padded_sizeof<T>();
const char* base =
reinterpret_cast<const char*>(p_data_) + (i + linear_offset) * padded_stride;
auto rtn = *c_style_pointer_cast<const buf_t*>(base);
return bit_cast<X>(rtn);
}
#endif
@@ -872,7 +883,8 @@ struct buffer_view<address_space_enum::lds,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
constexpr index_t padded_stride = lds_padded_sizeof<T>();
smem_load<sizeof(X)>{}(dst, v_offset * padded_stride, i_offset * padded_stride);
}
template <typename X,

View File

@@ -631,21 +631,24 @@ struct tile_scatter_gather
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// buffer load with dwordx3 requires 128-bit alignment
constexpr index_t lds_stride = lds_padded_sizeof<LdsDataType>();
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
lds_stride;
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
lds_stride -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
lds_stride -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
@@ -780,9 +783,12 @@ struct tile_scatter_gather
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
lds_coord.get_offset() / Traits::PackedSize +
lds_ys_offset / Traits::PackedSize;
// Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride)
CK_TILE_LDS_ADDR LdsDataType* smem =
reinterpret_cast<CK_TILE_LDS_ADDR LdsDataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(lds_base_ptr) +
(lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize *
lds_padded_sizeof<LdsDataType>());
const auto dram_ys_offset = [&]() {
if constexpr(static_move_ys)

View File

@@ -501,21 +501,23 @@ struct tile_window_with_static_distribution
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
constexpr index_t lds_stride = lds_padded_sizeof<LdsDataType>();
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
lds_stride;
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
lds_stride -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
lds_stride -
size_per_buf;
// Use VALU so the compiler can optimize redundant/repeated computations
@@ -628,8 +630,12 @@ struct tile_window_with_static_distribution
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
// Use byte arithmetic for dwordx3 padding (12-byte elements use 16-byte LDS stride)
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize;
reinterpret_cast<CK_TILE_LDS_ADDR LdsDataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(lds_base_ptr) +
(lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize *
lds_padded_sizeof<LdsDataType>());
const auto dram_ys_offset = [&]() {
if constexpr(static_move_ys)

View File

@@ -61,6 +61,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
tf32_t,
pk_fp4_t,
pk_fp4_raw_t,
pk_fp6x16_t,
pk_int4_t,
I8,
I32,
@@ -135,6 +136,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
tf32_t,
pk_fp4_t,
pk_fp4_raw_t,
pk_fp6x16_t,
pk_int4_t,
I8,
I32,

View File

@@ -169,6 +169,41 @@ struct FillUniformDistribution<ck_tile::pk_int4_t>
}
};
template <>
struct FillUniformDistribution<ck_tile::pk_fp6x16_t>
{
float a_{-2.f};
float b_{2.f};
std::optional<uint32_t> seed_{11939};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
while(first != last)
{
ck_tile::pk_fp6x16_t pk{};
for(ck_tile::index_t i = 0; i < ck_tile::pk_fp6x16_t::packed_size; ++i)
{
pk.pack(ck_tile::pk_fp6x16_t::float_to_fp6_e2m3(dis(gen)), i);
}
*first = pk;
++first;
}
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
namespace impl {
// clang-format off

View File

@@ -146,9 +146,10 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size =
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size = lds_padded_sizeof<OverrideADataType>() *
a_lds_block_desc.get_element_space_size() /
APackedSize;
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(a_lds_block_space_size, 16);

View File

@@ -834,9 +834,10 @@ struct UniversalGemmBasePolicy
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A)) *
numeric_traits<A>::PackedSize;
return (KPack < VecElems) ? KPack : VecElems;
return ck_tile::min(KPack, VecElems);
}
template <typename Problem>
@@ -846,9 +847,10 @@ struct UniversalGemmBasePolicy
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B)) *
numeric_traits<B>::PackedSize;
return (KPack < VecElems) ? KPack : VecElems;
return ck_tile::min(KPack, VecElems);
}
template <typename Problem>
@@ -857,8 +859,10 @@ struct UniversalGemmBasePolicy
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
constexpr index_t smem_size_a =
integer_least_multiple(a_lds_block_desc.get_element_space_size() *
lds_padded_sizeof<ADataType>() / APackedSize,
16);
return smem_size_a;
}
@@ -871,8 +875,10 @@ struct UniversalGemmBasePolicy
typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
constexpr index_t smem_size_b =
integer_least_multiple(b_lds_block_desc.get_element_space_size() *
lds_padded_sizeof<BDataType>() / BPackedSize,
16);
return smem_size_b;
}

View File

@@ -442,10 +442,12 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
MWarp / BlockSize,
"BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() ==
MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize),
MPerBlock *
(KPerBlock * lds_padded_sizeof<ADataType>() / APackedSize),
"SmemSizeA size is wrong!");
static_assert(Policy::template GetSmemSizeB<Problem>() ==
(KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock,
(KPerBlock * lds_padded_sizeof<BDataType>() / BPackedSize) *
NPerBlock,
"SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////

View File

@@ -9,7 +9,8 @@ endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp6 test_mx_gemm_fp6.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp6 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp)
target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -87,6 +87,13 @@ struct MXfp4_GemmConfig16 : MxGemmConfig
static constexpr ck_tile::index_t K_Tile = 256;
};
// Tile configuration for the MXFP6 GEMM tests; mirrors the MXfp4/MXfp8
// "Config16" siblings (64x64 output tile, 256-wide K tile).
struct MXfp6_GemmConfig16 : MxGemmConfig
{
    static constexpr ck_tile::index_t M_Tile = 64;
    static constexpr ck_tile::index_t N_Tile = 64;
    static constexpr ck_tile::index_t K_Tile = 256;
};
struct MXfp8_GemmConfig16 : MxGemmConfig
{
static constexpr ck_tile::index_t M_Tile = 64;

View File

@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_mx_gemm_config.hpp"
#include "test_mx_gemm_util.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using MxFp6Types = ::testing::Types<
std::tuple<ck_tile::pk_fp6x16_t, ck_tile::pk_fp6x16_t, MXfp6_GemmConfig16, Row, Col, Row>>;
// Typed-test fixture that unpacks the test tuple into TestMxGemmUtil's
// template parameters — per MxFp6Types these are presumably
// (AType, BType, GemmConfig, ALayout, BLayout, CLayout); confirm against
// TestMxGemmUtil's declaration in test_mx_gemm_util.hpp.
template <typename TypeParam>
class TestMxGemmFp6 : public TestMxGemmUtil<std::tuple_element_t<0, TypeParam>,
std::tuple_element_t<1, TypeParam>,
std::tuple_element_t<2, TypeParam>,
std::tuple_element_t<3, TypeParam>,
std::tuple_element_t<4, TypeParam>,
std::tuple_element_t<5, TypeParam>>
{
};
TYPED_TEST_SUITE(TestMxGemmFp6, MxFp6Types);
// Smoke test over a few problem sizes; Run's arguments are presumably the
// GEMM dimensions (M, N, K) — verify against TestMxGemmUtil::Run.
TYPED_TEST(TestMxGemmFp6, BasicSizes)
{
    this->Run(64, 64, 256);
    this->Run(128, 128, 256);
    this->Run(64, 128, 512);
}

View File

@@ -4,7 +4,6 @@
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/check_err.hpp"