[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)

Cleanup and refactoring related to tile loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for
fp16/bf16 x fp8

Key changes:

- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and
refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead
of return values for consistency
- Removed unused variable declarations and simplified type deduction
logic
- Define load_tile_with_elementwise to use tuple types explicitly for
clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [X] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
SamiAario-AMD
2026-03-02 12:21:44 +00:00
committed by assistant-librarian[bot]
parent 0438ab1b79
commit 95dc496d30
47 changed files with 190 additions and 182 deletions

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -218,10 +218,8 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
}
// C += A * B
@@ -290,9 +288,9 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
@@ -349,10 +347,8 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_lds_gemm_window);
}
// C += A * B

View File

@@ -79,9 +79,7 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename SrcDataType = void,
typename DstDataType = void,
index_t UnaryOpSize = 8,
template <index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
typename DramTileWindowStep>
@@ -89,7 +87,7 @@ struct GemmPipelineAgBgCrImplBase
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
load_and_convert_tile<UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
@@ -124,7 +122,7 @@ struct GemmPipelineAgBgCrImplBase
bool_constant<LoadTranspose> = {}) const
{
if constexpr(LoadTranspose)
dst_block_tile = load_tile_transpose(lds_tile_window);
load_tile_transpose(dst_block_tile, lds_tile_window);
else
load_tile(dst_block_tile, lds_tile_window);
}
@@ -241,12 +239,6 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
// with pk_int4_t load transpose the LDS type is always BDataType
using ADataTypeLDS =
std::conditional_t<std::is_same_v<typename Problem::ADataType, pk_int4_t>,
typename Problem::BDataType,
typename Problem::ADataType>;
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
@@ -258,11 +250,16 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_load_tile_distr = []() {
if constexpr(is_a_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<typename ALdsLoadTileDistr::DstrEncode,
ADataTypeLDS>::TransposedDstrEncode{});
typename InputTileDistributionTraits<
typename ALdsLoadTileDistr::DstrEncode,
typename ALdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return ALdsLoadTileDistr{};
}
}();
auto a_lds_gemm_window =
@@ -333,18 +330,18 @@ struct GemmPipelineAgBgCrImplBase
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
BLdsDataType>::TransposedDstrEncode{});
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename BLdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return BLdsLoadTileDistr{};
}
}();
auto b_lds_gemm_window =

View File

@@ -127,7 +127,6 @@ struct UniversalGemmBasePolicy
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
if constexpr(is_a_load_tr<Problem>)
{
@@ -261,6 +260,7 @@ struct UniversalGemmBasePolicy
}
else // A is in RowMajor
{
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
@@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// // Prefetch A0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
@@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
do
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
@@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
HotLoopScheduler();
}
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
@@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
if constexpr(TailNum == TailNumber::Even)
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(