mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit '47cd0d5cff77658adc1c9f184c012ec3496e8214' into develop
This commit is contained in:
58
include/ck_tile/ops/common/load_interleaved_pk_type.hpp
Normal file
58
include/ck_tile/ops/common/load_interleaved_pk_type.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class T>
|
||||
struct is_pk_int4 : std::false_type
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_pk_int4<pk_int4_t> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ComputeDataType, index_t UnaryOpSize>
|
||||
struct InterleavedPKTypeLoader
|
||||
{
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BDataType,
|
||||
typename ComputeDataType,
|
||||
index_t UnaryOpSize,
|
||||
typename WarpTile,
|
||||
typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
|
||||
{
|
||||
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
}
|
||||
else
|
||||
{
|
||||
dst = load_tile(src);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user