[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)

Cleanup and refactoring related to tile loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for
fp16/bf16 x fp8

Key changes:

- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and
refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead
of return values for consistency
- Removed unused variable declarations and simplified type deduction
logic
- Define load_tile_with_elementwise to use tuple types explicitly for
clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [X] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
SamiAario-AMD
2026-03-02 12:21:44 +00:00
committed by assistant-librarian[bot]
parent 0438ab1b79
commit 95dc496d30
47 changed files with 190 additions and 182 deletions

View File

@@ -5,22 +5,20 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
struct ConverterLoader
{
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
const auto tmp = load_tile(src);
// NOTE: we rely on types packing neatly here
using RawSrcType = typename SrcDataType::type;
@@ -29,29 +27,28 @@ struct InterleavedPKTypeLoader
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
const element_wise::PassThroughPack8 elementwise_op{};
elementwise_op(dst.get_thread_buffer().template get_as<DstVectorType>()(i),
tmp.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
}
};
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
template <index_t UnaryOpSize, bool LoadTranspose = false, typename WarpTile, typename WarpWindow>
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_packed_type_v<SrcDataType>)
using SrcDataType = typename WarpWindow::Base::DataType;
using DstDataType = typename WarpTile::DataType;
if constexpr(is_packed_type_v<SrcDataType> && !is_packed_type_v<DstDataType>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
dst, src);
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
load_tile_transpose(dst, src);
}
else
{