[CK_TILE] add tf32 support (#4302)

## Proposed changes

TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.

## Checklist

Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [ ] Any dependent changes have been merged

## Discussion



---
🔁 Imported from
[ROCm/composable_kernel#3538](https://github.com/ROCm/composable_kernel/pull/3538)
🧑‍💻 Originally authored by @yingluAMD

---------

Co-authored-by: yingluAMD <Yingmao.Lu@amd.com>
Co-authored-by: assistant-librarian[bot] <assistant-librarian[bot]@users.noreply.github.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
assistant-librarian[bot]
2026-03-19 10:17:20 +01:00
committed by GitHub
parent 1333922f04
commit 39bc8453c6
30 changed files with 1164 additions and 260 deletions

View File

@@ -41,6 +41,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
// Pass tf32_t as A/B types - epilogue auto-detects and maps to float for data operations
return run_gemm_example_prec_type<GemmConfig,
Invoker,
ck_tile::tf32_t,
ck_tile::tf32_t,
float>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig,

View File

@@ -6,8 +6,8 @@
struct BasicInvoker
{
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename ADataType_,
typename BDataType_,
typename DsDataType,
typename AccDataType,
typename CDataType,
@@ -19,14 +19,30 @@ struct BasicInvoker
typename CDEElementWise>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeCompute = ADataType_;
using BDataTypeCompute = BDataType_;
using ADataTypeBuf = ck_tile::if_select_t<ADataType_, ck_tile::tf32_t, float, ADataType_>;
using BDataTypeBuf = ck_tile::if_select_t<BDataType_, ck_tile::tf32_t, float, BDataType_>;
if constexpr(std::is_same_v<ADataTypeCompute, ck_tile::tf32_t>)
{
static_assert(std::is_same_v<ADataTypeCompute, BDataTypeCompute>,
"ADataTypeCompute and BDataTypeCompute must be the same");
}
if constexpr(Persistent)
{
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
}
constexpr bool is_fp32_input = std::is_same_v<ADataTypeBuf, float>;
constexpr bool is_tf32_compute = std::is_same_v<ADataTypeCompute, ck_tile::tf32_t>;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t K_Tile = 64;
#if CK_TILE_USE_WMMA
@@ -38,12 +54,14 @@ struct BasicInvoker
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
// gfx950: fp32 uses 16x16x16 tile (native MFMA)
// tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
@@ -61,17 +79,21 @@ struct BasicInvoker
BLayout,
CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataTypeBuf,
BDataTypeBuf,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ADataTypeCompute>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<ADataTypeCompute,
BDataTypeCompute,
ck_tile::tuple<>,
AccDataType,
CDataType,
@@ -112,7 +134,7 @@ struct BasicInvoker
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataTypeBuf, BDataTypeBuf>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
@@ -125,16 +147,21 @@ struct BasicInvoker
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
ck_tile::HostTensor<ADataTypeBuf> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
ck_tile::HostTensor<BDataTypeBuf> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr =
std::make_unique<ck_tile::RotatingMemWrapper<ADataTypeBuf, BDataTypeBuf>>(
kargs.as_ptr[0],
kargs.bs_ptr[0],
s.rotating_count_,
size_a_buffer,
size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {

View File

@@ -35,6 +35,10 @@ struct GemmConfigBase
static constexpr bool TiledMMAPermuteN = false;
};
// Type trait for tf32 storage type (tf32 uses float for memory layout calculations)
template <typename T>
using prec_storage_type = ck_tile::if_select_t<T, ck_tile::tf32_t, float, T>;
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
@@ -81,7 +85,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -121,7 +125,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -293,7 +297,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -302,7 +306,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile<prec_storage_type<PrecType>, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -324,6 +328,15 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
using ADataType = float;
using BDataType = float;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmTypeConfig<ck_tile::half_t>
{
@@ -486,7 +499,7 @@ inline auto create_args()
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t/tf32 (tf32 only on gfx950)")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -30,6 +30,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
@@ -205,11 +206,13 @@ std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_ge
return std::make_tuple(M, N, K);
}
// ADataType_ and BDataType_ are original types (e.g., tf32_t for TF32 mode)
// They are passed through invoke_gemm to invoker for tf32 auto-detection
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -218,7 +221,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
// ADataTypeCompute: compute type (tf32_t for TF32 mode, used for warp gemm selection)
// ADataTypeBuf: buffer/storage type (fp32 when tf32, from TypeConfig)
using ADataTypeCompute = ADataType_;
using BDataTypeCompute = BDataType_;
// Use GemmTypeConfig to get actual data types for tensor operations
// This handles tf32 -> float mapping for host tensors and device buffers
using TypeConfig = GemmTypeConfig<ADataType_, BDataType_, CDataType_>;
using ADataTypeBuf = typename TypeConfig::ADataType;
using BDataTypeBuf = typename TypeConfig::BDataType;
using CDataType = typename TypeConfig::CDataType;
using AccDataType = typename TypeConfig::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -242,27 +256,27 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::HostTensor<ADataTypeBuf> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::HostTensor<BDataTypeBuf> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataTypeBuf>{-2.f, 2.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataTypeBuf>{-2.f, 2.f}(b_k_n);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
ck_tile::FillMonotonicSeq<ADataTypeBuf>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataTypeBuf>{}(b_k_n);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataTypeBuf>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataTypeBuf>{1.f, 1.f}(b_k_n);
}
else
{
@@ -274,7 +288,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
{
if constexpr(GemmConfig::UseStructuredSparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
ck_tile::AdjustToStructuredSparsity<ADataTypeBuf>{}(a_m_k);
}
}
@@ -286,7 +300,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if constexpr(preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
ck_tile::HostTensor<BDataTypeBuf> b_shuffle_host = [&]() {
if constexpr(GemmConfig::TiledMMAPermuteN)
{
std::cout << "Run with PermuteN" << std::endl;
@@ -299,7 +313,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
}
}();
// shuffled buffer B for device implementation
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
{
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
}
@@ -307,16 +321,16 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
}
else
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::HostTensor<BDataTypeBuf> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PermuteB)
{
permute_tensor_b<GemmConfig,
decltype(b_k_n_dev),
ADataType,
BDataType,
ADataTypeBuf,
BDataTypeBuf,
AccDataType,
CDataType,
ALayout,
@@ -343,8 +357,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
float ave_time = invoke_gemm<GemmConfig,
Invoker,
ADataType,
BDataType,
ADataTypeCompute,
BDataTypeCompute,
ck_tile::tuple<>,
AccDataType,
CDataType,
@@ -371,8 +385,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
sizeof(BDataType) * N * K / ck_tile::numeric_traits<BDataType>::PackedSize +
sizeof(ADataTypeBuf) * M * K / ck_tile::numeric_traits<ADataTypeBuf>::PackedSize +
sizeof(BDataTypeBuf) * N * K / ck_tile::numeric_traits<BDataTypeBuf>::PackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
@@ -381,8 +395,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataTypeBuf>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataTypeBuf>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
@@ -397,17 +411,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if(arg_parser.get_int("v") == 1)
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
const auto rtol_atol =
calculate_rtol_atol<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
}
else if(arg_parser.get_int("v") == 2)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
{
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
@@ -421,12 +436,12 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ADataTypeBuf* d_A = static_cast<ADataTypeBuf*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataTypeBuf* d_B = static_cast<BDataTypeBuf*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
ck_tile::reference_gemm_gpu<ADataTypeCompute,
BDataTypeCompute,
AccDataType,
CDataType,
ALayout,
@@ -437,8 +452,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
const auto rtol_atol =
calculate_rtol_atol<ADataTypeCompute, BDataTypeCompute, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
}
@@ -447,8 +463,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
dump_gemm_json_results<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
ADataTypeBuf,
BDataTypeBuf,
CDataType,
GemmConfig,
ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"),

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#if CK_TILE_USE_LLVM_BUILTIN_BF16
#include <hip/hip_bfloat16.h>
#endif
@@ -440,4 +441,62 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x)
return bf16x2_t{float_to_bf16<rounding>(x.x), float_to_bf16<rounding>(x.y)};
}
// Available on gfx94x (gfx942, gfx950) and later
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
{
#if defined(__gfx94__) && CK_TILE_USE_LLVM_BUILTIN_BF16
return __builtin_convertvector(fp32x2_t{a, b}, bf16x2_t);
#else
return fp32x2_to_bf16x2(fp32x2_t{a, b});
#endif
}
// Packed bf16x2 to fp32x2 conversion
CK_TILE_HOST_DEVICE constexpr fp32x2_t bf16x2_to_fp32x2(bf16x2_t x)
{
#if CK_TILE_USE_LLVM_BUILTIN_BF16
return __builtin_convertvector(x, fp32x2_t);
#else
uint32_t packed = bit_cast<uint32_t>(x);
float f0 = bit_cast<float>(packed << 16);
float f1 = bit_cast<float>(packed & 0xFFFF0000u);
return fp32x2_t{f0, f1};
#endif
}
#ifndef CK_TILE_TF32_USE_PACKED_CVT
#define CK_TILE_TF32_USE_PACKED_CVT 1
#endif
template <int VecSize>
CK_TILE_DEVICE void convert_float_to_bf16_pairs(const ext_vector_t<float, VecSize>& reg_f32,
ext_vector_t<bfloat16_t, VecSize>& reg_bf16_big,
ext_vector_t<bfloat16_t, VecSize>& reg_bf16_small)
{
#if defined(__gfx94__) && CK_TILE_TF32_USE_PACKED_CVT && CK_TILE_USE_LLVM_BUILTIN_BF16
static_assert(VecSize % 2 == 0, "VecSize must be even for packed operations");
#pragma unroll
for(int i = 0; i < VecSize; i += 2)
{
fp32x2_t orig = {reg_f32[i], reg_f32[i + 1]};
bf16x2_t big_pair = cvt_pk_bf16_f32(orig[0], orig[1]);
fp32x2_t big_f32 = bf16x2_to_fp32x2(big_pair);
fp32x2_t diff = orig - big_f32;
bf16x2_t small_pair = cvt_pk_bf16_f32(diff[0], diff[1]);
reinterpret_cast<bf16x2_t*>(&reg_bf16_big)[i / 2] = big_pair;
reinterpret_cast<bf16x2_t*>(&reg_bf16_small)[i / 2] = small_pair;
}
#else
#pragma unroll
for(int i = 0; i < VecSize; i++)
{
reg_bf16_big[i] = float_to_bf16(reg_f32[i]);
reg_bf16_small[i] = float_to_bf16(reg_f32[i] - bf16_to_float(reg_bf16_big[i]));
}
#endif
}
} // namespace ck_tile

View File

@@ -0,0 +1,80 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <type_traits>
namespace ck_tile {
// this structure is used to pick up the <base> type inside
// using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allow native type + bool in this term (custom type will fail)
// overload this structure to let proper <base> type
template <typename T>
struct native_t
{
using type = remove_cvref_t<T>;
};
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_, typename = void>
struct ext_vector;
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
} // namespace impl
template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;
} // namespace ck_tile

View File

@@ -9,6 +9,11 @@
namespace ck_tile {
// TF32 tag type: 1 sign bit, 8 exponent bits, 10 mantissa bits (see numeric_traits<tf32_t>)
struct tf32_t
{
};
// this struct has the information of
// 1. limit of a certain type, simliar to std::numeric_limits
// 2. some pre-defined value, zero, one...
@@ -101,6 +106,25 @@ struct numeric_traits<float>
using bitwise_type = uint32_t;
};
template <>
struct numeric_traits<tf32_t>
{
static constexpr int exp = 8;
static constexpr int mant = 10;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \

View File

@@ -57,6 +57,44 @@ CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
static constexpr uint32_t float32_exponent_mask = 0x7f800000u;
enum class tf32_rounding_mode
{
trunc = 0, // truncate
rne = 1, // round to nearest even (RTNE)
};
template <tf32_rounding_mode rounding = tf32_rounding_mode::trunc>
CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
{
uint32_t i = bit_cast<uint32_t>(x);
if constexpr(rounding == tf32_rounding_mode::rne)
{
// RTNE rounding.
if((i & float32_exponent_mask) != float32_exponent_mask)
{
// Add rounding bias for round-to-nearest-even (RTNE) before truncating:
// - 0xfff is the rounding bias corresponding to the 13 fraction bits that
// will be discarded.
// - (i >> 13) & 1 extracts the least significant of those discarded bits and
// adding it implements "ties to even" (round half-way cases to even).
i += 0xfff + ((i >> 13) & 1);
}
}
// Zero out the lowest 13 fraction bits to form the TF32-like value.
i &= 0xFFFFE000u;
return bit_cast<float>(i);
}
template <typename Y,
tf32_rounding_mode rounding = tf32_rounding_mode::trunc,
std::enable_if_t<std::is_same_v<Y, tf32_t>, bool> = false>
CK_TILE_HOST_DEVICE constexpr float type_convert(float x)
{
return float_to_tf32<rounding>(x);
}
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)

View File

@@ -5,7 +5,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
@@ -13,77 +13,9 @@
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// this structure is used to pick up the <base> type inside
// using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allow native type + bool in this term (custom type will fail)
// overload this structure to let proper <base> type
template <typename T>
struct native_t
{
using type = remove_cvref_t<T>;
};
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_, typename = void>
struct ext_vector;
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
} // namespace impl
template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T, typename = void>

View File

@@ -112,6 +112,11 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop
}
// Template ternary: if Cond == Match, use TrueType, else FalseType
// Usage: if_select_t<T, int, float, double> evaluates to float if T==int, else double
template <typename Cond, typename Match, typename TrueType, typename FalseType>
using if_select_t = std::conditional_t<std::is_same_v<Cond, Match>, TrueType, FalseType>;
template <typename CompareTo, typename... Rest>
struct is_any_of : std::false_type
{

View File

@@ -58,6 +58,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
F16,
BF16,
F32,
tf32_t,
pk_fp4_t,
pk_fp4_raw_t,
pk_int4_t,
@@ -76,8 +77,9 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
@@ -90,8 +92,9 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
@@ -129,6 +132,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
F16,
BF16,
F32,
tf32_t,
pk_fp4_t,
pk_fp4_raw_t,
pk_int4_t,
@@ -151,8 +155,9 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
compute_error = std::pow(2, discrete_expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
static_assert(
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
@@ -168,8 +173,9 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
static_assert(
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)

View File

@@ -4,11 +4,11 @@
#pragma once
#include <cstdlib>
#include <mutex>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ck_tile {
@@ -447,24 +447,34 @@ CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void
reference_gemm(const HostTensor<if_select_t<ADataType_, tf32_t, float, ADataType_>>& a_m_k,
const HostTensor<if_select_t<BDataType_, tf32_t, float, BDataType_>>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
const bool is_gfx950 = (ck_tile::get_device_name() == "gfx950");
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
@@ -472,7 +482,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by
// PackedSize So a_m_k(m,0) and a_m_k(m,1) return the same packed byte
@@ -481,7 +491,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
}
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = a_m_k(m, k);
@@ -493,7 +503,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
// HostTensor automatically handles packed indexing
const pk_fp4_t pk_val = b_k_n(k, n);
@@ -501,7 +511,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
}
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = b_k_n(k, n);
@@ -513,7 +523,36 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
if(is_gfx950)
{
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small = ck_tile::type_convert<bf16_t>(
v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small = ck_tile::type_convert<bf16_t>(
v_b - type_convert<AccDataType>(v_b_bf16_big));
v_acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
}
else
{
// Other architectures: tf32 not supported or handled via fp32 fallback
v_acc += v_a * v_b;
}
}
else
{
v_acc += v_a * v_b;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
@@ -764,15 +803,15 @@ reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
__global__ void naive_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
@@ -781,6 +820,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
@@ -790,8 +837,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
@@ -802,7 +849,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
@@ -810,7 +857,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
@@ -822,7 +869,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
@@ -830,7 +877,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
@@ -842,7 +889,33 @@ __global__ void naive_gemm_kernel(ADataType* A,
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
#ifdef CK_GFX950_SUPPORT
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small =
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small =
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
#else
// Other architectures: use fp32 fallback
acc += v_a * v_b;
#endif
}
else
{
acc += v_a * v_b;
}
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
@@ -852,15 +925,15 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
BDataType* B,
__global__ void blockwise_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
@@ -874,6 +947,14 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
float* scale_A_ptr,
float* scale_B_ptr)
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
@@ -902,8 +983,8 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
@@ -914,7 +995,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
@@ -922,7 +1003,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
@@ -935,7 +1016,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
@@ -943,7 +1024,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
@@ -955,7 +1036,33 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc_temp += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
#ifdef CK_GFX950_SUPPORT
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small =
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small =
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
acc_temp += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
#else
// Other architectures: use fp32 fallback
acc_temp += v_a * v_b;
#endif
}
else
{
acc_temp += v_a * v_b;
}
}
// final accumulation
acc += acc_temp * scale_A * scale_B;
@@ -974,8 +1081,8 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1002,8 +1109,8 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_blockwise_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1040,15 +1147,15 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
return;
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_batched_gemm_gpu(if_select_t<ADataType_, tf32_t, float, ADataType_>* a_ptr,
if_select_t<BDataType_, tf32_t, float, BDataType_>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1061,18 +1168,29 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
index_t batch_stride_C,
index_t batch_count)
{
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
using ADataTypeCompute = ADataType_;
using BDataTypeCompute = BDataType_;
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
ADataTypeBuf* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataTypeBuf* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataTypeCompute,
BDataTypeCompute,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC><<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
return;

View File

@@ -89,19 +89,32 @@ struct CShuffleEpilogue
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
// ADataTypeCompute: compute type from Problem (may be tf32_t for TF32 mode)
using ADataTypeCompute = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataTypeCompute = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
std::is_same_v<ADataType, pk_fp4_t>,
BDataType,
ADataType>;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataTypeCompute, tf32_t, float, ADataTypeCompute>;
using BDataTypeBuf = if_select_t<BDataTypeCompute, tf32_t, float, BDataTypeCompute>;
// For warp gemm selection: use tf32_t if compute type was tf32_t
// For pk_int4/pk_fp4: use the other data type
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataTypeCompute, tf32_t>,
tf32_t,
std::conditional_t<std::is_same_v<ADataTypeBuf, pk_int4_t> ||
std::is_same_v<ADataTypeBuf, pk_fp4_t>,
BDataTypeBuf,
ADataTypeBuf>>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataTypeCompute, tf32_t>,
tf32_t,
std::conditional_t<std::is_same_v<BDataTypeBuf, pk_int4_t> ||
std::is_same_v<BDataTypeBuf, pk_fp4_t> ||
sizeof(BDataTypeBuf) < sizeof(ADataTypeBuf),
ADataTypeBuf,
BDataTypeBuf>>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
@@ -137,7 +150,7 @@ struct CShuffleEpilogue
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "CShuffleEpilogue",
return concat('_', "CShuffleEpilogue",
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
@@ -440,8 +453,8 @@ struct CShuffleEpilogue
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
// this branch is for original a16w4
if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
if constexpr(is_950 || is_any_of<ADataTypeBuf, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataTypeBuf, pk_int4_t, pk_fp4_t>::value)
{
if constexpr(EightWave)
{

View File

@@ -229,15 +229,6 @@ CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
return result;
}
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
{
bf16x2_t result;
asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
: [result] "=v"(result)
: [a] "v"(a), [b] "v"(b));
return result;
}
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
{
fp32x2_t result;
@@ -856,7 +847,7 @@ struct BlockFmhaFwdV3Pipeline
}
else
{
auto casted = detail::cvt_pk_bf16_f32(x, y);
auto casted = ck_tile::cvt_pk_bf16_f32(x, y);
sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
}

View File

@@ -49,6 +49,7 @@ struct GemmPipelineAgBgCrImplBase
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
// - For 4-byte types (float/tf32): transpose load not supported
static constexpr bool is_a_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
@@ -57,6 +58,8 @@ struct GemmPipelineAgBgCrImplBase
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(sizeof(ADataType) >= 4)
return false; // 4-byte types (float/tf32) don't support transpose load
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
@@ -71,6 +74,8 @@ struct GemmPipelineAgBgCrImplBase
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(sizeof(BDataType) >= 4)
return false; // 4-byte types (float/tf32) don't support transpose load
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else

View File

@@ -909,26 +909,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using ATypeToUse = if_select_t<ADataType, pk_int4_t, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using WarpGemm =
WarpGemmDispatcher<if_select_t<ComputeDataType, tf32_t, tf32_t, ATypeToUse>,
if_select_t<ComputeDataType, tf32_t, tf32_t, BTypeToUse>,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,

View File

@@ -257,33 +257,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
// Use ComputeDataType to detect tf32 mode for warp gemm selection
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
// Determine compute types to use
// This logic defaults to A/B DataType, but if one of them is packed falls back to the other
// If both are packed, it falls back to the explicitly defined ComputeDataType in the
// problem It might be a good idea to use ComputeDataType anyway, but that would break how
// this behaviour used to work
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::ComputeDataType>;
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
using ATypeToUse =
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
using BTypeToUse =
mixed_prec_compute_type_from_input_t<BDataType, ADataType, ComputeDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
// For tf32 mode, use tf32_t for warp gemm; otherwise use original types
using WarpGemm =
WarpGemmDispatcher<if_select_t<ComputeDataType, tf32_t, tf32_t, ATypeToUse>,
if_select_t<ComputeDataType, tf32_t, tf32_t, BTypeToUse>,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,

View File

@@ -48,6 +48,28 @@ using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
4,
AttrNumAccess>>;
// tf32
// On gfx950: uses 3x bf16 MFMA emulation (no native xf32 support)
#if defined(CK_GFX950_SUPPORT)
// gfx950: tf32 emulated using 3x bf16 MFMA
using WarpGemmMfmaTf32Tf32F32M32N32K16Native = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaTf32Tf32F32M16N16K32Native = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaTf32Tf32F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfma<
WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
#endif
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<

View File

@@ -190,6 +190,141 @@ struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
}
};
// tf32/xf32 emulation on gfx950 using 3x bf16 MFMA
// Algorithm: split float into bf16_big and bf16_small, then compute:
// out = A_big * B_big + A_small * B_big + A_big * B_small
// This provides tf32-like precision using bf16 hardware
// V_MFMA_F32_32x32x16_XF32 emulated on gfx950 using 3x bf16 32x32x16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF32F32F32M32N32K16Tf32Gfx950
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = float;
using BDataType = float;
using CDataType = float;
// Input: 8 floats for K=16 (each lane holds 8 elements, kABKPerLane=8)
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
#if defined(__gfx950__)
// Convert float to bf16 pairs using packed instructions
ext_vector_t<bf16_t, 8> a_big, a_small, b_big, b_small;
convert_float_to_bf16_pairs<8>(a_vec, a_big, a_small);
convert_float_to_bf16_pairs<8>(b_vec, b_big, b_small);
// Run 3 bf16 MFMAs: small*big, big*small, big*big
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_small, b_big, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_small, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_big, b_big, c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
CVecType c_vec{0.f};
(*this)(c_vec, a_vec, b_vec);
return c_vec;
}
};
// V_MFMA_F32_16x16x32_XF32 emulated on gfx950 using 3x bf16 16x16x32
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF32F32F32M16N16K32Tf32Gfx950
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = float;
using BDataType = float;
using CDataType = float;
// Input: 8 floats for K=32 (each lane holds 8 elements, kABKPerLane=8)
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 4>;
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 32;
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
#if defined(__gfx950__)
// Convert float to bf16 pairs using packed instructions
ext_vector_t<bf16_t, 8> a_big, a_small, b_big, b_small;
convert_float_to_bf16_pairs<8>(a_vec, a_big, a_small);
convert_float_to_bf16_pairs<8>(b_vec, b_big, b_small);
// Run 3 bf16 MFMAs: small*big, big*small, big*big
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_small, b_big, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_small, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_big, b_big, c_vec, 0, 0, 0);
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
CVecType c_vec{0.f};
(*this)(c_vec, a_vec, b_vec);
return c_vec;
}
};
// V_MFMA_F32_16x16x32_BF16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32

View File

@@ -40,6 +40,22 @@ template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Typ
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 8, false, false, false, EDouble> { using Type = WarpGemmMfmaF32F32F32M32N32K8<EDouble>; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
// tf32 (on gfx950: uses 3x bf16 MFMA emulation)
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
#if defined(CK_GFX950_SUPPORT)
template<> struct Dispatcher<tf32_t, tf32_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
template<> struct Dispatcher<tf32_t, tf32_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<>; };
template<> struct Dispatcher<tf32_t, tf32_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<EDouble>; };
template<> struct Dispatcher<tf32_t, tf32_t, float, 32, 32, 16, false, false, false, EQuad> { using Type = WarpGemmMfmaTf32Tf32F32M32N32K16<EQuad>; };
// TF32 16x16x32 for weight preshuffle pipeline (uses native 16x16x32 TF32 MFMA emulation)
template<> struct Dispatcher<tf32_t, tf32_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32<>; };
template<> struct Dispatcher<tf32_t, tf32_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32<EDouble>; };
template<> struct Dispatcher<tf32_t, tf32_t, float, 16, 16, 32, false, false, false, EQuad> { using Type = WarpGemmMfmaTf32Tf32F32M16N16K32<EQuad>; };
#endif
// Note: For gfx11/gfx12 and other architectures that don't support tf32,
// these dispatchers are not defined. Code using tf32 should be guarded
// by CK_ENABLE_TF32 or CK_GFX950_SUPPORT macros.
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };

View File

@@ -7,6 +7,8 @@ endif()
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
add_gtest_executable(test_ck_tile_tf32 test_tf32.cpp)
add_gtest_executable(test_ck_tile_bf16_f32_convert test_bf16_f32_convert.cpp)
endif()
if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8)

View File

@@ -0,0 +1,248 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include <cmath>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
using ck_tile::bf16_to_float;
using ck_tile::bf16x2_t;
using ck_tile::bfloat16_t;
using ck_tile::bit_cast;
using ck_tile::float_to_bf16;
using ck_tile::fp32x2_t;
// =====================================================================
// Tests for bf16x2_to_fp32x2 (host-side, always available)
// =====================================================================
TEST(Bf16F32Convert, Bf16x2ToFp32x2_BasicValues)
{
auto a = float_to_bf16(1.0f);
auto b = float_to_bf16(-2.5f);
bf16x2_t packed{a, b};
fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
EXPECT_FLOAT_EQ(result[0], bf16_to_float(a));
EXPECT_FLOAT_EQ(result[1], bf16_to_float(b));
}
TEST(Bf16F32Convert, Bf16x2ToFp32x2_Zeros)
{
auto pos_zero = float_to_bf16(0.0f);
auto neg_zero = float_to_bf16(-0.0f);
bf16x2_t packed{pos_zero, neg_zero};
fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
EXPECT_FLOAT_EQ(result[0], 0.0f);
EXPECT_TRUE(std::signbit(result[1]));
EXPECT_FLOAT_EQ(result[1], -0.0f);
}
TEST(Bf16F32Convert, Bf16x2ToFp32x2_LargeSmall)
{
auto big = float_to_bf16(65504.0f);
auto small = float_to_bf16(0.00390625f);
bf16x2_t packed{big, small};
fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
EXPECT_FLOAT_EQ(result[0], bf16_to_float(big));
EXPECT_FLOAT_EQ(result[1], bf16_to_float(small));
}
TEST(Bf16F32Convert, Bf16x2ToFp32x2_RoundTrip)
{
const float test_values[] = {1.0f, -1.0f, 0.5f, 3.14f, 100.0f, -42.0f, 0.001f};
for(float v : test_values)
{
auto bf = float_to_bf16(v);
float expected = bf16_to_float(bf);
bf16x2_t packed{bf, bf};
fp32x2_t result = ck_tile::bf16x2_to_fp32x2(packed);
EXPECT_FLOAT_EQ(result[0], expected) << "v=" << v;
EXPECT_FLOAT_EQ(result[1], expected) << "v=" << v;
}
}
// =====================================================================
// Tests for fp32x2_to_bf16x2 (host-side)
// =====================================================================
TEST(Bf16F32Convert, Fp32x2ToBf16x2_BasicValues)
{
fp32x2_t input{1.5f, -3.0f};
bf16x2_t result = ck_tile::fp32x2_to_bf16x2(input);
EXPECT_FLOAT_EQ(bf16_to_float(result[0]), bf16_to_float(float_to_bf16(1.5f)));
EXPECT_FLOAT_EQ(bf16_to_float(result[1]), bf16_to_float(float_to_bf16(-3.0f)));
}
// =====================================================================
// Device tests for cvt_pk_bf16_f32 and convert_float_to_bf16_pairs
// =====================================================================
struct CvtPkBf16F32Result
{
bfloat16_t r0;
bfloat16_t r1;
};
__global__ void kernel_cvt_pk_bf16_f32(const float* in, CvtPkBf16F32Result* out, int n)
{
int idx = threadIdx.x;
if(idx < n)
{
bf16x2_t result = ck_tile::cvt_pk_bf16_f32(in[2 * idx], in[2 * idx + 1]);
out[idx].r0 = result[0];
out[idx].r1 = result[1];
}
}
TEST(Bf16F32Convert, CvtPkBf16F32_Device)
{
const std::vector<float> host_in = {1.0f, -1.0f, 0.0f, 3.14f, 100.0f, -0.5f, 42.0f, 0.001f};
const int num_pairs = host_in.size() / 2;
ck_tile::DeviceMem in_buf(host_in.size() * sizeof(float));
ck_tile::DeviceMem out_buf(num_pairs * sizeof(CvtPkBf16F32Result));
in_buf.ToDevice(host_in.data());
kernel_cvt_pk_bf16_f32<<<1, num_pairs>>>(
static_cast<const float*>(in_buf.GetDeviceBuffer()),
static_cast<CvtPkBf16F32Result*>(out_buf.GetDeviceBuffer()),
num_pairs);
(void)hipDeviceSynchronize();
std::vector<CvtPkBf16F32Result> host_out(num_pairs);
out_buf.FromDevice(host_out.data());
for(int i = 0; i < num_pairs; i++)
{
float ref0 = bf16_to_float(float_to_bf16(host_in[2 * i]));
float ref1 = bf16_to_float(float_to_bf16(host_in[2 * i + 1]));
EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r0), ref0) << "pair=" << i << " elem=0";
EXPECT_FLOAT_EQ(bf16_to_float(host_out[i].r1), ref1) << "pair=" << i << " elem=1";
}
}
// =====================================================================
// Device test for convert_float_to_bf16_pairs
// =====================================================================
template <int VecSize>
struct Bf16PairsResult
{
bfloat16_t big[VecSize];
bfloat16_t small_val[VecSize];
};
template <int VecSize>
__global__ void kernel_convert_float_to_bf16_pairs(const float* in, Bf16PairsResult<VecSize>* out)
{
using float_vec_t = ck_tile::ext_vector_t<float, VecSize>;
using bf16_vec_t = ck_tile::ext_vector_t<bfloat16_t, VecSize>;
float_vec_t reg_f32;
for(int i = 0; i < VecSize; i++)
reg_f32[i] = in[i];
bf16_vec_t reg_big, reg_small;
ck_tile::convert_float_to_bf16_pairs<VecSize>(reg_f32, reg_big, reg_small);
for(int i = 0; i < VecSize; i++)
{
out[0].big[i] = reg_big[i];
out[0].small_val[i] = reg_small[i];
}
}
template <int VecSize>
void test_convert_float_to_bf16_pairs_device()
{
static_assert(VecSize >= 2 && VecSize % 2 == 0);
std::vector<float> host_in(VecSize);
// Use diverse values: mix of exact and non-exact bf16 representable numbers
const float base_vals[] = {1.1f, -2.3f, 0.7f, 100.1f, -0.001f, 42.42f, 3.14f, -7.77f};
for(int i = 0; i < VecSize; i++)
host_in[i] = base_vals[i % 8];
ck_tile::DeviceMem in_buf(VecSize * sizeof(float));
ck_tile::DeviceMem out_buf(sizeof(Bf16PairsResult<VecSize>));
in_buf.ToDevice(host_in.data());
kernel_convert_float_to_bf16_pairs<VecSize>
<<<1, 1>>>(static_cast<const float*>(in_buf.GetDeviceBuffer()),
static_cast<Bf16PairsResult<VecSize>*>(out_buf.GetDeviceBuffer()));
(void)hipDeviceSynchronize();
Bf16PairsResult<VecSize> host_out;
out_buf.FromDevice(&host_out);
for(int i = 0; i < VecSize; i++)
{
float orig = host_in[i];
float big_f = bf16_to_float(host_out.big[i]);
// big should match scalar float_to_bf16
float ref_big = bf16_to_float(float_to_bf16(orig));
EXPECT_FLOAT_EQ(big_f, ref_big) << "VecSize=" << VecSize << " i=" << i;
// small should match float_to_bf16(orig - big)
float ref_small = bf16_to_float(float_to_bf16(orig - ref_big));
float small_f = bf16_to_float(host_out.small_val[i]);
EXPECT_FLOAT_EQ(small_f, ref_small) << "VecSize=" << VecSize << " i=" << i;
// big + small should be closer to orig than big alone
float reconstructed = big_f + small_f;
EXPECT_LE(std::fabs(reconstructed - orig), std::fabs(big_f - orig) + 1e-10f)
<< "VecSize=" << VecSize << " i=" << i;
}
}
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec2) { test_convert_float_to_bf16_pairs_device<2>(); }
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec4) { test_convert_float_to_bf16_pairs_device<4>(); }
TEST(Bf16F32Convert, ConvertFloatToBf16Pairs_Vec8) { test_convert_float_to_bf16_pairs_device<8>(); }
// =====================================================================
// 3x BF16 multiply-accumulate precision test
// =====================================================================
TEST(Bf16F32Convert, ThreeBf16MulAccPrecision)
{
// Verify that a_big*b_big + a_small*b_big + a_big*b_small is more precise
// than a single bf16(a)*bf16(b) for non-exact values
const float test_pairs[][2] = {
{1.1f, 2.3f}, {3.14f, -2.71f}, {0.123f, 456.789f}, {-100.1f, 0.99f}};
for(const auto& pair : test_pairs)
{
float a = pair[0];
float b = pair[1];
float a_big_f = bf16_to_float(float_to_bf16(a));
float a_small_f = bf16_to_float(float_to_bf16(a - a_big_f));
float b_big_f = bf16_to_float(float_to_bf16(b));
float b_small_f = bf16_to_float(float_to_bf16(b - b_big_f));
float exact = a * b;
float single_bf16 = a_big_f * b_big_f;
float three_bf16 = a_big_f * b_big_f + a_small_f * b_big_f + a_big_f * b_small_f;
float err_single = std::fabs(exact - single_bf16);
float err_three = std::fabs(exact - three_bf16);
EXPECT_LE(err_three, err_single + 1e-10f)
<< "a=" << a << " b=" << b << " exact=" << exact << " single=" << single_bf16
<< " three=" << three_bf16;
}
}

View File

@@ -0,0 +1,86 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include <cmath>
#include <cstring>
#include <limits>
#include "ck_tile/core.hpp"
using ck_tile::bit_cast;
using ck_tile::numeric_traits;
using ck_tile::tf32_rounding_mode;
using ck_tile::tf32_t;
using ck_tile::type_convert;
static uint32_t to_bits(float x) { return bit_cast<uint32_t>(x); }
static float from_bits(uint32_t i) { return bit_cast<float>(i); }
TEST(ConvertTest, NumericTraits)
{
EXPECT_EQ(numeric_traits<tf32_t>::exp, 8);
EXPECT_EQ(numeric_traits<tf32_t>::mant, 10);
EXPECT_EQ(numeric_traits<tf32_t>::bias, 127);
EXPECT_EQ(numeric_traits<tf32_t>::PackedSize, 1);
}
TEST(ConvertTest, ToTf32Trunc)
{
// exact values (low 13 bits already zero)
EXPECT_EQ(to_bits(type_convert<tf32_t>(1.0f)), 0x3F800000u); // 1.0f
EXPECT_EQ(to_bits(type_convert<tf32_t>(-1.0f)), 0xBF800000u); // -1.0f
EXPECT_EQ(to_bits(type_convert<tf32_t>(0.0f)), 0x00000000u); // +0.0f
EXPECT_EQ(to_bits(type_convert<tf32_t>(-0.0f)), 0x80000000u); // -0.0f
EXPECT_EQ(to_bits(type_convert<tf32_t>(2.0f)), 0x40000000u); // 2.0f
EXPECT_EQ(to_bits(type_convert<tf32_t>(0.5f)), 0x3F000000u); // 0.5f
// truncation zeros the low 13 mantissa bits
EXPECT_EQ(to_bits(type_convert<tf32_t>(1.1f)), 0x3F8CC000u); // 1.1f (0x3F8CCCCD)
EXPECT_EQ(to_bits(type_convert<tf32_t>(3.14159265358979323846f)),
0x40490000u); // pi (0x40490FDB)
EXPECT_EQ(to_bits(type_convert<tf32_t>(123.456f)),
0x42F6E000u); // 123.456f (0x42F6E979)
EXPECT_EQ(to_bits(type_convert<tf32_t>(-3.14f)), 0xC048E000u); // -3.14f (0xC048F5C3)
// special values
EXPECT_EQ(to_bits(type_convert<tf32_t>(std::numeric_limits<float>::infinity())), 0x7F800000u);
EXPECT_EQ(to_bits(type_convert<tf32_t>(-std::numeric_limits<float>::infinity())), 0xFF800000u);
EXPECT_TRUE(std::isnan(type_convert<tf32_t>(std::numeric_limits<float>::quiet_NaN())));
EXPECT_EQ(to_bits(type_convert<tf32_t>(std::numeric_limits<float>::denorm_min())), 0x00000000u);
// property: low 13 bits must be zero, top 19 bits preserved
for(float val : {1.0f, 1.5f, 2.0f, 0.1f, 100.0f, -42.5f, 1e10f, 1e-10f})
{
uint32_t orig = to_bits(val);
uint32_t tf32 = to_bits(type_convert<tf32_t>(val));
EXPECT_EQ(tf32 & 0xFFFFE000u, tf32) << "val=" << val;
EXPECT_EQ(orig & 0xFFFFE000u, tf32) << "val=" << val;
}
}
TEST(ConvertTest, ToTf32Rtne)
{
// exact values (low 13 bits already zero)
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(1.0f)),
0x3F800000u); // 1.0f
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(-1.0f)),
0xBF800000u); // -1.0f
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(0.0f)),
0x00000000u); // +0.0f
// past midpoint (bit12 + bit11 set) -> rounds up
float val = from_bits(0x3F801800u);
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(val)), 0x3F802000u);
// special values (keep the same as float)
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(
std::numeric_limits<float>::infinity())),
0x7F800000u); // infinity in float is 0x7F800000
EXPECT_EQ(to_bits(type_convert<tf32_t, tf32_rounding_mode::rne>(
-std::numeric_limits<float>::infinity())),
0xFF800000u); // negative infinity in float is 0xFF800000
EXPECT_TRUE(std::isnan(type_convert<tf32_t, tf32_rounding_mode::rne>(
std::numeric_limits<float>::quiet_NaN()))); // quiet NaN in float is 0x7FC00000
}

View File

@@ -46,8 +46,8 @@ test_cshuffle_epilogue_kernel(const typename Problem::AccDataType* __restrict__
__shared__ char smem[Epilogue::GetSmemSize()];
// Create accumulator tile with GEMM accumulator distribution (matches BlockGemm)
using WG = ck_tile::WarpGemmDispatcher<typename Epilogue::ADataType,
typename Epilogue::BDataType,
using WG = ck_tile::WarpGemmDispatcher<typename Epilogue::ATypeToUse,
typename Epilogue::BTypeToUse,
typename Problem::AccDataType,
Problem::MPerXdl,
Problem::NPerXdl,

View File

@@ -46,8 +46,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async test_gemm_pipeline_comp_async.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_comp_async PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS})
add_gtest_executable(test_ck_tile_gemm_pipeline_tf32_mem test_gemm_pipeline_tf32_mem.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_tf32_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
list(APPEND CK_TILE_GEMM_TEST_TARGETS
test_ck_tile_gemm_pipeline_comp_async
test_ck_tile_gemm_pipeline_tf32_mem
)
add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async_eight_waves test_gemm_pipeline_comp_async_eight_waves.cpp)

View File

@@ -320,4 +320,14 @@ using KernelTypesPersistentWmma = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV3, NonPersistent>
>;
// TF32 (gfx950 only): 3x bf16 MFMA emulation, uses float buffers with tf32_t compute type
// Tile: 128x128x64, Warp tile: 32x32x16
using KernelTypesTf32Mem = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_BlockSize, N_BlockSize, K_BlockSize, M_TileSize, N_TileSize, K_TileSize, Scheduler, PipelineType
std::tuple< Row, Row, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Row, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Interwave, Mem>,
std::tuple< Row, Col, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Intrawave, Mem>,
std::tuple< Row, Col, Row, TF32, TF32, F32, F32, I128, I128, I64, I32, I32, I16, Interwave, Mem>
>;
// clang-format on

View File

@@ -13,3 +13,5 @@ using BF16 = ck_tile::bf16_t;
using BF8 = ck_tile::bf8_t;
using I4 = ck_tile::pk_int4_t;
using TF32 = ck_tile::tf32_t;

View File

@@ -0,0 +1,22 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineTf32Mem
: public TestCkTileGemmPipeline<T, TestCkTileGemmPipelineTf32Mem<T>>
{
public:
static constexpr bool check_data_type() { return true; }
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineTf32Mem
TYPED_TEST_SUITE(TEST_SUITE_NAME, KernelTypesTf32Mem);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -135,6 +135,10 @@ class TestCkTileGemmPipeline : public ::testing::Test
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 15, std::false_type>::value;
// TF32 uses tf32_t as compute type but float as buffer/storage type
using ADataTypeBuf = ck_tile::if_select_t<ADataType, ck_tile::tf32_t, float, ADataType>;
using BDataTypeBuf = ck_tile::if_select_t<BDataType, ck_tile::tf32_t, float, BDataType>;
protected:
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
@@ -183,12 +187,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
NumWaveGroup,
preshuffle>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataTypeBuf,
BDataTypeBuf,
AccDataType,
GemmShape,
GemmUniversalTraits,
Scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ADataType>;
using GemmPipeline =
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
@@ -304,24 +312,23 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::index_t stride_C =
ck_tile::get_default_stride(M, N, StrideC, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::HostTensor<ADataTypeBuf> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::HostTensor<BDataTypeBuf> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(BLayout{})));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);
ck_tile::FillUniformDistributionIntegerValue<ADataTypeBuf>{-5, 5, 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataTypeBuf>{-5, 5, 11940}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::HostTensor<BDataTypeBuf> b_k_n_dev = b_k_n;
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}

View File

@@ -20,6 +20,12 @@ struct DataTypeTraits<float>
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<ck_tile::tf32_t>
{
static constexpr const char* name = "tf32";
};
template <>
struct DataTypeTraits<double>
{