[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)

[CK_TILE] add tf32 support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.

## Checklist

Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [ ] Any dependent changes have been merged

## Discussion
This commit is contained in:
yinglu
2026-03-19 09:19:06 +00:00
committed by assistant-librarian[bot]
parent 652d3456ca
commit d460ab35b6
30 changed files with 1164 additions and 260 deletions

View File

@@ -49,6 +49,7 @@ struct GemmPipelineAgBgCrImplBase
// that only work for certain K warp tile sizes based on data type size:
// - For 1-byte types (fp8/bf8): K warp tile <= 64
// - For 2-byte types (fp16/bf16): K warp tile <= 32
// - For 4-byte types (float/tf32): transpose load not supported
static constexpr bool is_a_load_tr = []() {
using WarpTile = typename BlockGemmShape::WarpTile;
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
@@ -57,6 +58,8 @@ struct GemmPipelineAgBgCrImplBase
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(sizeof(ADataType) >= 4)
return false; // 4-byte types (float/tf32) don't support transpose load
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else
@@ -71,6 +74,8 @@ struct GemmPipelineAgBgCrImplBase
return false;
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
return false;
else if constexpr(sizeof(BDataType) >= 4)
return false; // 4-byte types (float/tf32) don't support transpose load
else if constexpr(kKWarpTile > kMaxKWarpTile)
return false;
else

View File

@@ -909,26 +909,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using ATypeToUse = if_select_t<ADataType, pk_int4_t, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using WarpGemm =
WarpGemmDispatcher<if_select_t<ComputeDataType, tf32_t, tf32_t, ATypeToUse>,
if_select_t<ComputeDataType, tf32_t, tf32_t, BTypeToUse>,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
BTypeToUse,

View File

@@ -257,33 +257,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
// Use ComputeDataType to detect tf32 mode for warp gemm selection
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
// Determine compute types to use
// This logic defaults to A/B DataType, but if one of them is packed falls back to the other
// If both are packed, it falls back to the explicitly defined ComputeDataType in the
// problem It might be a good idea to use ComputeDataType anyway, but that would break how
// this behaviour used to work
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::ComputeDataType>;
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
using ATypeToUse =
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
using BTypeToUse =
mixed_prec_compute_type_from_input_t<BDataType, ADataType, ComputeDataType>;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
// For tf32 mode, use tf32_t for warp gemm; otherwise use original types
using WarpGemm =
WarpGemmDispatcher<if_select_t<ComputeDataType, tf32_t, tf32_t, ATypeToUse>,
if_select_t<ComputeDataType, tf32_t, tf32_t, BTypeToUse>,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,