MX GEMM - FP6 Support in GEMM MX v3 Pipeline (#2481)

* Add GEMM MX BF6 example

* Fix BF6 type_convert

* Add type_convert for bf16x6

* Add compare operator to f4x2_pk_t

* Update README for 67_gemm_microscaling

* Fix host tensor initialization with integer values for FP8



[ROCm/composable_kernel commit: 518dc21ae8]
This commit is contained in:
Andriy Roshchenko
2025-07-11 13:07:05 -06:00
committed by GitHub
parent f3120e7526
commit a024e11036
11 changed files with 303 additions and 15 deletions

View File

@@ -13,6 +13,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp6)
add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf6)
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
@@ -62,3 +65,4 @@ example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
set(FP6_MXGEMM_OPTIONS)
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS})

View File

@@ -8,14 +8,16 @@ Custom verification parameters:
# arg2: initialization (0=constant values, 1=integer values, 2=decimal values)
# arg3: time kernel (0=no, 1=yes)
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC
# arg11: KBatch
# arg12: warmup runs pre-timing
# arg13: repeat run count for timing
./bin/example_gemm_mx_fp8 1 1 0 1
```
Custom tensor shapes:
```bash
./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1
./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10
```
Default invocation:

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::bf6x16_pk_t;
using BDataType = ck::bf6x16_pk_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 16; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
1, // AK1
1, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -100,8 +100,11 @@ bool parse_cmd_args(int argc,
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
<< "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl
<< "arg12: warmup runs pre-timing" << std::endl
<< "arg13: repeat run count for timing" << std::endl;
return false;
}

View File

@@ -550,7 +550,14 @@ struct Tensor
auto dis_ = dis; // copy
g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v<T>);
auto t_fn = [&]() {
if constexpr(ck::packed_size_v<T> == 1)
// As user can pass integer distribution in dis, we must ensure that the correct
// constructor/converter is called at all times. For f4/f6/f8 types, to ensure
// correct results, we convert from float to the target type. In these cases
// integer constructors are interpreted as direct initialization of the internal
// storage with binary values instead of treating integers as subset of floats.
if constexpr(ck::is_same_v<T, ck::f8_t> || ck::is_same_v<T, ck::bf8_t>)
return ck::type_convert<T>(static_cast<float>(fn(dis_(g_))));
else if constexpr(ck::packed_size_v<T> == 1)
return ck::type_convert<T>(fn(dis_(g_)));
else if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(

View File

@@ -1118,6 +1118,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x16x2_t& reg_a,
const int32_t scale_a,
const bf6x16x2_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
using arg_type = int32x8_t;
arg_type arg_a{
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
0,
0};
arg_type arg_b{
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
0,
0};
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_a,
arg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
OpselA, // OPSEL
scale_a,
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,

View File

@@ -60,6 +60,17 @@ struct f4x2_pk_t
{
return (x0 << 4) | (x1 & 0b00001111);
}
// Compare operator
__host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
return lhs.data == rhs.data;
}
__host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs)
{
return !(lhs == rhs);
}
};
template <typename BitType, index_t pk_size>

View File

@@ -2254,8 +2254,9 @@ using f6x16x2_t = typename vector_type<f6x16_pk_t, 2>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x16x2_t = typename vector_type<bf6x16_pk_t, 2>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// e8m0
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;

View File

@@ -2102,17 +2102,15 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
float float_array[32];
} in{x};
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} out{};
using array_type = uint8_t __attribute__((ext_vector_type(32)));
array_type uint8_array;
// collect the 6-bit values into an array
ck::static_for<0, 32, 1>{}([&](auto i) {
out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
uint8_array[static_cast<index_t>(i)] =
utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
});
return out.bf6_vector;
return bf6x32_t{bf6x32_pk_t{uint8_array}};
#endif
}
@@ -2257,6 +2255,37 @@ inline __host__ __device__ bf6x32_pk_t type_convert<bf6x32_pk_t, float32_t>(floa
return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
}
template <>
inline __host__ __device__ bf6x16_t type_convert<bf6x16_t, float16_t>(float16_t x)
{
union
{
float16_t v16x2[2];
float32_t v32;
} in{{x, x}};
union
{
bf6x32_t v32;
bf6x16_t v16x2[2];
} out{};
#if CK_USE_SR_F6_CONVERSION
out.v32 = bf6_convert_sr(in.v32);
#else
out.v32 = bf6_convert_rne(in.v32);
#endif
return out.v16x2[0];
}
template <>
inline __host__ __device__ bf6x16_pk_t type_convert<bf6x16_pk_t, float16_t>(float16_t x)
{
return static_cast<bf6x16_pk_t>(type_convert<bf6x16_t>(x));
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
@@ -2329,6 +2358,32 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
return out.float_vector;
#endif
}
template <>
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_t>(bf6x16_t x)
{
union
{
bf6x16_t v16x2[2];
bf6x32_t v32;
} in{{x, x}};
union
{
float16_t v16x2[2];
float32_t v32;
} out{};
out.v32 = type_convert<float32_t>(in.v32);
return out.v16x2[0];
}
template <>
inline __host__ __device__ float16_t type_convert<float16_t, bf6x16_pk_t>(bf6x16_pk_t x)
{
return type_convert<float16_t>(static_cast<bf6x16_t>(x));
}
#endif
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
template <typename Y, typename X, size_t NumElems>

View File

@@ -60,6 +60,7 @@ if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_bf6 test_bf6.cpp)
if(result EQUAL 0)
target_compile_options(test_bf6 PRIVATE -mavx512f)
target_link_libraries(test_bf6 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_bf6)

View File

@@ -6,6 +6,7 @@
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
#include "ck/library/utility/device_memory.hpp"
using ck::bf6_convert_rne;
using ck::bf6_convert_sr;
@@ -455,3 +456,57 @@ TEST(BF6, TestAllValues)
}
});
}
__global__ void test_bf6_convert_rne(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
ck::float32_t float32_in(1.0f);
ck::float32_t float32_out{};
auto bf6x32_vec = bf6_convert_rne(float32_in);
float32_out = type_convert<ck::float32_t>(bf6x32_vec);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32_out[static_cast<int>(ii)]; });
i = N;
}
TEST(MXBF6, DeviceBF6ConvertRNE)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_bf6_convert_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
EXPECT_EQ(N, completed);
ck::static_for<0, N, 1>{}(
[&](auto ii) { EXPECT_EQ(out[static_cast<int>(ii)], 1.0f) << "ii: " << ii << std::endl; });
auto bf6x32_vec_tc = ck::type_convert<bf6x32_pk_t>(ck::float32_t(1.0f));
auto bf6x32_vec_cnstr = bf6x32_pk_t(0x0C);
EXPECT_EQ(bf6x32_vec_tc, bf6x32_vec_cnstr);
}