[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)

[CK_TILE] add tf32 support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion
This commit is contained in:
yinglu
2026-03-19 09:19:06 +00:00
committed by assistant-librarian[bot]
parent 652d3456ca
commit d460ab35b6
30 changed files with 1164 additions and 260 deletions

View File

@@ -7,6 +7,8 @@ endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
add_gtest_executable(test_ck_tile_tf32 test_tf32.cpp)
add_gtest_executable(test_ck_tile_bf16_f32_convert test_bf16_f32_convert.cpp)
endif()
if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)

View File

@@ -0,0 +1,248 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include <cmath>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
using ck_tile::bf16_to_float;
using ck_tile::bf16x2_t;
using ck_tile::bfloat16_t;
using ck_tile::bit_cast;
using ck_tile::float_to_bf16;
using ck_tile::fp32x2_t;
// =====================================================================
// Tests for bf16x2_to_fp32x2 (host-side, always available)
// =====================================================================
// Packing two bf16 values and unpacking to fp32 must reproduce the exact
// bf16-rounded values for ordinary finite inputs.
TEST(Bf16F32Convert, Bf16x2ToFp32x2_BasicValues)
{
    const auto lo = float_to_bf16(1.0f);
    const auto hi = float_to_bf16(-2.5f);
    const bf16x2_t pair{lo, hi};
    const fp32x2_t unpacked = ck_tile::bf16x2_to_fp32x2(pair);
    EXPECT_FLOAT_EQ(unpacked[0], bf16_to_float(lo));
    EXPECT_FLOAT_EQ(unpacked[1], bf16_to_float(hi));
}
// Signed zeros must survive the bf16x2 -> fp32x2 conversion: +0 stays +0
// and -0 keeps its sign bit.
TEST(Bf16F32Convert, Bf16x2ToFp32x2_Zeros)
{
    const bf16x2_t pair{float_to_bf16(0.0f), float_to_bf16(-0.0f)};
    const fp32x2_t unpacked = ck_tile::bf16x2_to_fp32x2(pair);
    EXPECT_FLOAT_EQ(unpacked[0], 0.0f);
    EXPECT_FLOAT_EQ(unpacked[1], -0.0f);
    EXPECT_TRUE(std::signbit(unpacked[1]));
}
// Magnitude extremes: a large value (65504, incidentally the fp16 maximum)
// and a small power of two (2^-8) must both match the scalar reference
// after the packed conversion.
TEST(Bf16F32Convert, Bf16x2ToFp32x2_LargeSmall)
{
    const auto large_val = float_to_bf16(65504.0f);
    const auto tiny_val  = float_to_bf16(0.00390625f);
    const bf16x2_t pair{large_val, tiny_val};
    const fp32x2_t unpacked = ck_tile::bf16x2_to_fp32x2(pair);
    EXPECT_FLOAT_EQ(unpacked[0], bf16_to_float(large_val));
    EXPECT_FLOAT_EQ(unpacked[1], bf16_to_float(tiny_val));
}
// For a spread of representative floats, packing the same bf16 value into
// both lanes and unpacking must yield the scalar bf16_to_float reference
// in each lane.
TEST(Bf16F32Convert, Bf16x2ToFp32x2_RoundTrip)
{
    for(const float v : {1.0f, -1.0f, 0.5f, 3.14f, 100.0f, -42.0f, 0.001f})
    {
        const auto bf      = float_to_bf16(v);
        const float ref    = bf16_to_float(bf);
        const fp32x2_t out = ck_tile::bf16x2_to_fp32x2(bf16x2_t{bf, bf});
        EXPECT_FLOAT_EQ(out[0], ref) << "v=" << v;
        EXPECT_FLOAT_EQ(out[1], ref) << "v=" << v;
    }
}
// =====================================================================
// Tests for fp32x2_to_bf16x2 (host-side)
// =====================================================================
// Packing two floats into bf16x2 must agree lane-by-lane with the scalar
// float_to_bf16 conversion.
TEST(Bf16F32Convert, Fp32x2ToBf16x2_BasicValues)
{
    const fp32x2_t src{1.5f, -3.0f};
    const bf16x2_t packed = ck_tile::fp32x2_to_bf16x2(src);
    const float ref0 = bf16_to_float(float_to_bf16(1.5f));
    const float ref1 = bf16_to_float(float_to_bf16(-3.0f));
    EXPECT_FLOAT_EQ(bf16_to_float(packed[0]), ref0);
    EXPECT_FLOAT_EQ(bf16_to_float(packed[1]), ref1);
}
// =====================================================================
// Device tests for cvt_pk_bf16_f32 and convert_float_to_bf16_pairs
// =====================================================================
// Per-thread output of kernel_cvt_pk_bf16_f32: the two bf16 lanes produced
// by one cvt_pk_bf16_f32 call.
struct CvtPkBf16F32Result
{
    bfloat16_t r0; // converted value of the first float of the input pair
    bfloat16_t r1; // converted value of the second float of the input pair
};
// Each thread converts one adjacent pair of floats (in[2*tid], in[2*tid+1])
// with cvt_pk_bf16_f32 and stores the two resulting bf16 lanes.
__global__ void kernel_cvt_pk_bf16_f32(const float* in, CvtPkBf16F32Result* out, int n)
{
    const int tid = threadIdx.x;
    if(tid >= n)
    {
        return;
    }
    const bf16x2_t packed = ck_tile::cvt_pk_bf16_f32(in[2 * tid], in[2 * tid + 1]);
    out[tid].r0 = packed[0];
    out[tid].r1 = packed[1];
}
// Device-side check of cvt_pk_bf16_f32: each thread converts one pair of
// floats; every lane must match the scalar float_to_bf16 reference.
TEST(Bf16F32Convert, CvtPkBf16F32_Device)
{
    const std::vector<float> host_in = {1.0f, -1.0f, 0.0f, 3.14f, 100.0f, -0.5f, 42.0f, 0.001f};
    const int num_pairs = static_cast<int>(host_in.size() / 2);
    ck_tile::DeviceMem in_buf(host_in.size() * sizeof(float));
    ck_tile::DeviceMem out_buf(num_pairs * sizeof(CvtPkBf16F32Result));
    in_buf.ToDevice(host_in.data());
    kernel_cvt_pk_bf16_f32<<<1, num_pairs>>>(
        static_cast<const float*>(in_buf.GetDeviceBuffer()),
        static_cast<CvtPkBf16F32Result*>(out_buf.GetDeviceBuffer()),
        num_pairs);
    // Fail loudly on launch/runtime errors instead of discarding them with
    // (void)hipDeviceSynchronize(); otherwise a failed launch leaves out_buf
    // undefined and the comparisons below report confusing mismatches.
    ASSERT_EQ(hipGetLastError(), hipSuccess);
    ASSERT_EQ(hipDeviceSynchronize(), hipSuccess);
    std::vector<CvtPkBf16F32Result> host_out(num_pairs);
    out_buf.FromDevice(host_out.data());
    for(int i = 0; i < num_pairs; i++)
    {
        const float ref0 = bf16_to_float(float_to_bf16(host_in[2 * i]));
        const float ref1 = bf16_to_float(float_to_bf16(host_in[2 * i + 1]));
        EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r0), ref0) << "pair=" << i << " elem=0";
        EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r1), ref1) << "pair=" << i << " elem=1";
    }
}
// =====================================================================
// Device test for convert_float_to_bf16_pairs
// =====================================================================
// Device-to-host result buffer for convert_float_to_bf16_pairs: for each
// input lane i, big[i] is the bf16 "high" part of the original float and
// small_val[i] is the bf16 rounding of the residual (orig - big).
template <int VecSize>
struct Bf16PairsResult
{
    bfloat16_t big[VecSize];
    bfloat16_t small_val[VecSize];
};
// Single-thread kernel: load VecSize floats into a register vector, split
// each lane into a (big, small) bf16 pair via convert_float_to_bf16_pairs,
// and store both result vectors for host-side verification.
template <int VecSize>
__global__ void kernel_convert_float_to_bf16_pairs(const float* in, Bf16PairsResult<VecSize>* out)
{
    using float_vec_t = ck_tile::ext_vector_t<float, VecSize>;
    using bf16_vec_t  = ck_tile::ext_vector_t<bfloat16_t, VecSize>;
    float_vec_t src;
    for(int lane = 0; lane < VecSize; ++lane)
    {
        src[lane] = in[lane];
    }
    bf16_vec_t hi_part;
    bf16_vec_t lo_part;
    ck_tile::convert_float_to_bf16_pairs<VecSize>(src, hi_part, lo_part);
    for(int lane = 0; lane < VecSize; ++lane)
    {
        out->big[lane]       = hi_part[lane];
        out->small_val[lane] = lo_part[lane];
    }
}
template <int VecSize>
void test_convert_float_to_bf16_pairs_device()
{
static_assert(VecSize >= 2 && VecSize % 2 == 0);
std::vector<float> host_in(VecSize);
// Use diverse values: mix of exact and non-exact bf16 representable numbers
const float base_vals[] = {1.1f, -2.3f, 0.7f, 100.1f, -0.001f, 42.42f, 3.14f, -7.77f};
for(int i = 0; i < VecSize; i++)
host_in[i] = base_vals[i % 8];
ck_tile::DeviceMem in_buf(VecSize * sizeof(float));
ck_tile::DeviceMem out_buf(sizeof(Bf16PairsResult<VecSize>));
in_buf.ToDevice(host_in.data());
kernel_convert_float_to_bf16_pairs<VecSize>
<<<1, 1>>>(static_cast<const float*>(in_buf.GetDeviceBuffer()),
static_cast<Bf16PairsResult<VecSize>*>(out_buf.GetDeviceBuffer()));
(void)hipDeviceSynchronize();
Bf16PairsResult<VecSize> host_out;
out_buf.FromDevice(&host_out);
for(int i = 0; i < VecSize; i++)
{
float orig = host_in[i];
float big_f = bf16_to_float(host_out.big[i]);
// big should match scalar float_to_bf16
float ref_big = bf16_to_float(float_to_bf16(orig));
EXPECT_FLOAT_EQ(big_f, ref_big) << "VecSize=" << VecSize << " i=" << i;
// small should match float_to_bf16(orig - big)
float ref_small = bf16_to_float(float_to_bf16(orig - ref_big));
float small_f = bf16_to_float(host_out.small_val[i]);
EXPECT_FLOAT_EQ(small_f, ref_small) << "VecSize=" << VecSize << " i=" << i;
// big + small should be closer to orig than big alone
float reconstructed = big_f + small_f;
EXPECT_LE(std::fabs(reconstructed - orig), std::fabs(big_f - orig) + 1e-10f)
<< "VecSize=" << VecSize << " i=" << i;
}
}
// Instantiate the pairs-conversion device test for each supported width.
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec2)
{
    test_convert_float_to_bf16_pairs_device<2>();
}
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec4)
{
    test_convert_float_to_bf16_pairs_device<4>();
}
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec8)
{
    test_convert_float_to_bf16_pairs_device<8>();
}
// =====================================================================
// 3x BF16 multiply-accumulate precision test
// =====================================================================
// The three-term bf16 product a_big*b_big + a_small*b_big + a_big*b_small
// must approximate the exact float product at least as well as the single
// bf16 product a_big*b_big for values that are not exactly representable.
TEST(Bf16F32Convert, ThreeBf16MulAccPrecision)
{
    const float test_pairs[][2] = {
        {1.1f, 2.3f}, {3.14f, -2.71f}, {0.123f, 456.789f}, {-100.1f, 0.99f}};
    for(const auto& p : test_pairs)
    {
        const float a = p[0];
        const float b = p[1];
        // Split each operand into a bf16 "big" part plus a bf16 residual.
        const float a_hi = bf16_to_float(float_to_bf16(a));
        const float a_lo = bf16_to_float(float_to_bf16(a - a_hi));
        const float b_hi = bf16_to_float(float_to_bf16(b));
        const float b_lo = bf16_to_float(float_to_bf16(b - b_hi));
        const float exact      = a * b;
        const float one_term   = a_hi * b_hi;
        const float three_term = a_hi * b_hi + a_lo * b_hi + a_hi * b_lo;
        EXPECT_LE(std::fabs(exact - three_term), std::fabs(exact - one_term) + 1e-10f)
            << "a=" << a << " b=" << b << " exact=" << exact << " single=" << one_term
            << " three=" << three_term;
    }
}

View File

@@ -0,0 +1,86 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include <cmath>
#include <cstring>
#include <limits>
#include "ck_tile/core.hpp"
using ck_tile::bit_cast;
using ck_tile::numeric_traits;
using ck_tile::tf32_rounding_mode;
using ck_tile::tf32_t;
using ck_tile::type_convert;
// Reinterpret a float's bit pattern as uint32_t (and back) so conversion
// results can be compared bit-exactly.
static uint32_t to_bits(float x)
{
    return bit_cast<uint32_t>(x);
}
static float from_bits(uint32_t i)
{
    return bit_cast<float>(i);
}
// TF32 numeric traits: an 8-bit exponent with the fp32 bias of 127, 10
// explicit mantissa bits, and a packed size of one element.
TEST(ConvertTest, NumericTraits)
{
    using traits = numeric_traits<tf32_t>;
    EXPECT_EQ(traits::exp, 8);
    EXPECT_EQ(traits::mant, 10);
    EXPECT_EQ(traits::bias, 127);
    EXPECT_EQ(traits::PackedSize, 1);
}
// Default (truncating) float -> tf32 conversion: the low 13 mantissa bits
// are cleared and the remaining top 19 bits are preserved bit-exactly.
TEST(ConvertTest, ToTf32Trunc)
{
    struct Case
    {
        float in;
        uint32_t expected;
    };
    const Case cases[] = {
        // exact values (low 13 bits already zero)
        {1.0f, 0x3F800000u},
        {-1.0f, 0xBF800000u},
        {0.0f, 0x00000000u},
        {-0.0f, 0x80000000u},
        {2.0f, 0x40000000u},
        {0.5f, 0x3F000000u},
        // truncation zeros the low 13 mantissa bits
        {1.1f, 0x3F8CC000u},                    // 1.1f is 0x3F8CCCCD
        {3.14159265358979323846f, 0x40490000u}, // pi is 0x40490FDB
        {123.456f, 0x42F6E000u},                // 123.456f is 0x42F6E979
        {-3.14f, 0xC048E000u},                  // -3.14f is 0xC048F5C3
        // special values
        {std::numeric_limits<float>::infinity(), 0x7F800000u},
        {-std::numeric_limits<float>::infinity(), 0xFF800000u},
        {std::numeric_limits<float>::denorm_min(), 0x00000000u},
    };
    for(const auto& c : cases)
        EXPECT_EQ(to_bits(type_convert<tf32_t>(c.in)), c.expected) << "in=" << c.in;
    // a NaN input must still be NaN after conversion
    EXPECT_TRUE(std::isnan(type_convert<tf32_t>(std::numeric_limits<float>::quiet_NaN())));
    // property: low 13 bits must be zero, top 19 bits preserved
    for(float val : {1.0f, 1.5f, 2.0f, 0.1f, 100.0f, -42.5f, 1e10f, 1e-10f})
    {
        const uint32_t orig = to_bits(val);
        const uint32_t tf32 = to_bits(type_convert<tf32_t>(val));
        EXPECT_EQ(tf32 & 0xFFFFE000u, tf32) << "val=" << val;
        EXPECT_EQ(orig & 0xFFFFE000u, tf32) << "val=" << val;
    }
}
// Round-to-nearest-even float -> tf32 conversion.
TEST(ConvertTest, ToTf32Rtne)
{
    const auto rne = [](float x) {
        return to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(x));
    };
    // exact values (low 13 bits already zero) pass through unchanged
    EXPECT_EQ(rne(1.0f), 0x3F800000u);  // 1.0f
    EXPECT_EQ(rne(-1.0f), 0xBF800000u); // -1.0f
    EXPECT_EQ(rne(0.0f), 0x00000000u);  // +0.0f
    // past midpoint (bit12 + bit11 set) -> rounds up
    EXPECT_EQ(rne(from_bits(0x3F801800u)), 0x3F802000u);
    // special values keep the same encodings as float
    EXPECT_EQ(rne(std::numeric_limits<float>::infinity()),
              0x7F800000u); // infinity in float is 0x7F800000
    EXPECT_EQ(rne(-std::numeric_limits<float>::infinity()),
              0xFF800000u); // negative infinity in float is 0xFF800000
    EXPECT_TRUE(std::isnan(type_convert<tf32_t, tf32_rounding_mode::rne>(
        std::numeric_limits<float>::quiet_NaN()))); // quiet NaN in float is 0x7FC00000
}

View File

@@ -46,8 +46,8 @@ test_cshuffle_epilogue_kernel(const typename Problem::AccDataType* __restrict__
__shared__ char smem[Epilogue::GetSmemSize()];
// Create accumulator tile with GEMM accumulator distribution (matches BlockGemm)
using WG = ck_tile::WarpGemmDispatcher<typename Epilogue::ADataType,
typename Epilogue::BDataType,
using WG = ck_tile::WarpGemmDispatcher<typename Epilogue::ATypeToUse,
typename Epilogue::BTypeToUse,
typename Problem::AccDataType,
Problem::MPerXdl,
Problem::NPerXdl,

View File

@@ -46,8 +46,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async test_gemm_pipeline_comp_async.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_comp_async PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS})
add_gtest_executable(test_ck_tile_gemm_pipeline_tf32_mem test_gemm_pipeline_tf32_mem.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_tf32_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
list(APPEND CK_TILE_GEMM_TEST_TARGETS
test_ck_tile_gemm_pipeline_comp_async
test_ck_tile_gemm_pipeline_tf32_mem
)
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async_eight_waves test_gemm_pipeline_comp_async_eight_waves.cpp)

View File

@@ -320,4 +320,14 @@ using KernelTypesPersistentWmma = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
>;
// TF32 (gfx950 only): 3x bf16 MFMA emulation, uses float buffers with tf32_t compute type
// Tile: 128x128x64, Warp tile: 32x32x16
using KernelTypesTf32Mem = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType
std::tuple< Row, Row, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Interwave, Mem>
>;
// clang-format on

View File

@@ -13,3 +13,5 @@ using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using I4 = ck_tile::pk_int4_t;
using TF32 = ck_tile::tf32_t;

View File

@@ -0,0 +1,22 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineTf32Mem
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineTf32Mem<T>>
{
public:
static constexpr bool check_data_type() { return true; }
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineTf32Mem
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesTf32Mem);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -135,6 +135,10 @@ class TestCkTileGemmPipeline : public ::testing::Test
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
// TF32 uses tf32_t as compute type but float as buffer/storage type
using ADataTypeBuf = ck_tile::if_select_t<ADataType, ck_tile::tf32_t, float, ADataType>;
using BDataTypeBuf = ck_tile::if_select_t<BDataType, ck_tile::tf32_t, float, BDataType>;
protected:
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
@@ -183,12 +187,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
NumWaveGroup,
preshuffle>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataTypeBuf,
BDataTypeBuf,
AccDataType,
GemmShape,
GemmUniversalTraits,
Scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ADataType>;
using GemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
@@ -304,24 +312,23 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::HostTensor<ADataTypeBuf> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::HostTensor<BDataTypeBuf> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
ck_tile::FillUniformDistributionIntegerValue<ADataTypeBuf>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataTypeBuf>{-5, 5, 11940}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::HostTensor<BDataTypeBuf> b_k_n_dev = b_k_n;
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}