[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)

[CK_TILE] add tf32 support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.

## Checklist

Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [ ] Any dependent changes have been merged

## Discussion
This commit is contained in:
yinglu
2026-03-19 09:19:06 +00:00
committed by assistant-librarian[bot]
parent 652d3456ca
commit d460ab35b6
30 changed files with 1164 additions and 260 deletions

View File

@@ -35,6 +35,10 @@ struct GemmConfigBase
static constexpr bool TiledMMAPermuteN = false;
};
// Type trait for tf32 storage type (tf32 uses float for memory layout calculations)
template <typename T>
using prec_storage_type = ck_tile::if_select_t<T, ck_tile::tf32_t, float, T>;
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
@@ -81,7 +85,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -121,7 +125,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -293,7 +297,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(prec_storage_type<PrecType>);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -302,7 +306,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile<prec_storage_type<PrecType>, M_Warp_Tile, true>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -324,6 +328,15 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
using ADataType = float;
using BDataType = float;
using AccDataType = float;
using CDataType = float;
};
template <>
struct GemmTypeConfig<ck_tile::half_t>
{
@@ -486,7 +499,7 @@ inline auto create_args()
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t/tf32 (tf32 only on gfx950)")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")