[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)

Cleanup and refactoring related to tile loading
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for
fp16/bf16 x fp8

Key changes:

- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and
refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead
of return values for consistency
- Removed unused variable declarations and simplified type deduction
logic
- Define load_tile_with_elementwise to use tuple types explicitly for
clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [X] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
SamiAario-AMD
2026-03-02 12:21:44 +00:00
committed by assistant-librarian[bot]
parent 0438ab1b79
commit 95dc496d30
47 changed files with 190 additions and 182 deletions

View File

@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should works with unknow number of params
// Load element_wise API works only when the input typle is a tuple-tyupe
return tile_window[number<0>{}].load(
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
// TODO: Tile windows should work with unknown number of params
// Load element_wise API works only when the input type is a tuple-type
return tile_windows[number<0>{}].load(
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
@@ -131,7 +131,7 @@ template <typename T,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
CK_TILE_DEVICE void load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,

View File

@@ -374,6 +374,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -381,18 +382,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -402,21 +404,28 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
CK_TILE_DEVICE void load_tile_transpose_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
// Check that the tile distribution of out_tensor is the expected one for transposed loads.
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
static_assert(std::is_same_v<decltype(make_static_tile_distribution(OutTileDstrEncode{})),
remove_cvref_t<decltype(output_distr)>>);
// Check that the datatype of out_tensor matches that of the bottom tensor view.
static_assert(std::is_same_v<typename DistributedTensor_::DataType,
typename BottomTensorView_::DataType>);
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
@@ -443,8 +452,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});
return out_tensor;
}
/**
@@ -456,6 +463,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
@@ -463,16 +471,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* the last is SFINAE to ensure the tile distribution encoding is valid.
*
* @param out_tensor A statically distributed tensor containing the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* indexing.
*
* @return A statically distributed tensor containing the transposed tile data.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE void
load_tile_transpose(DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
}
template <
typename BottomTensorView_,
typename WindowLengths_,
@@ -489,7 +518,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
return load_tile_transpose_with_offset(tile_window, 0);
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
return out_tensor;
}
} // namespace ck_tile

View File

@@ -190,11 +190,11 @@ struct tile_window_with_static_distribution
* The same thread, during vectorized reading, accesses the same set of
* data from A0, A1, A2, … AN.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -202,7 +202,7 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor,
tile_window,
tile_windows,
elementwise,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
@@ -210,12 +210,12 @@ struct tile_window_with_static_distribution
}
template <typename DistributedTensor,
typename TileWindow_,
typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -226,14 +226,14 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto sizeOfTuple = TileWindow_::size();
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
@@ -244,7 +244,7 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const auto idx_vec_value = generate_tuple(
[&](auto jj) {
return tile_window[number<jj>{}]
return tile_windows[number<jj>{}]
.get_bottom_tensor_view()
.template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,

View File

@@ -8,7 +8,7 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,7 +6,7 @@
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -11,7 +11,7 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -3,7 +3,7 @@
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -5,22 +5,20 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
struct ConverterLoader
{
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
const auto tmp = load_tile(src);
// NOTE: we rely on types packing neatly here
using RawSrcType = typename SrcDataType::type;
@@ -29,29 +27,28 @@ struct InterleavedPKTypeLoader
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
const element_wise::PassThroughPack8 elementwise_op{};
elementwise_op(dst.get_thread_buffer().template get_as<DstVectorType>()(i),
tmp.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
}
};
template <typename SrcDataType,
typename DstDataType,
index_t UnaryOpSize,
bool LoadTranspose = false,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
template <index_t UnaryOpSize, bool LoadTranspose = false, typename WarpTile, typename WarpWindow>
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_packed_type_v<SrcDataType>)
using SrcDataType = typename WarpWindow::Base::DataType;
using DstDataType = typename WarpTile::DataType;
if constexpr(is_packed_type_v<SrcDataType> && !is_packed_type_v<DstDataType>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
dst, src);
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else if constexpr(LoadTranspose)
{
dst = load_tile_transpose(src);
load_tile_transpose(dst, src);
}
else
{

View File

@@ -9,7 +9,7 @@
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -11,7 +11,7 @@
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -22,7 +22,7 @@
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -61,7 +61,7 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
// STAGE 3, P^T@OGrad^T Gemm1
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(q_lds_write_window, q_dram_window);
async_load_tile(do_lds_write_window, do_dram_window);
__builtin_amdgcn_s_waitcnt(0);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
q_reg_tensor = load_tile(q_lds_read_window);
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
do_reg_tensor = load_tile(do_lds_read_window);
lse_block_tile = load_tile(lse_dram_window);
d_block_tile = load_tile(d_dram_window);
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
async_load_tile(v_lds_write_window, v_dram_window);
move_tile_window(v_dram_window, {kN0, 0});
s_waitcnt</*vmcnt=*/0>();
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
k_reg_tensor = load_tile(k_lds_read_window);
v_reg_tensor = load_tile(v_lds_read_window);
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
block_sync_lds();
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //

View File

@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
};
auto V_lds_load = [&](auto v_lds_read_idx) {
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
};
decltype(m) m_old;

View File

@@ -582,7 +582,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
// loop over along the [V]alue Sequence length
move_tile_window(v_lds_read_window, {kK1, 0});
v_tile = load_tile_transpose(v_lds_read_window);
load_tile_transpose(v_tile, v_lds_read_window);
});
// move back to the origin
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});

View File

@@ -15,7 +15,7 @@
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -79,7 +79,7 @@
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -218,10 +218,8 @@ struct BlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_block_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_block_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
}
// C += A * B
@@ -290,9 +288,9 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
@@ -349,10 +347,8 @@ struct BlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
a_lds_gemm_window);
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
b_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_lds_gemm_window);
}
// C += A * B

View File

@@ -79,9 +79,7 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
template <typename SrcDataType = void,
typename DstDataType = void,
index_t UnaryOpSize = 8,
template <index_t UnaryOpSize = 8,
typename DstBlockTile,
typename SrcTileWindow,
typename DramTileWindowStep>
@@ -89,7 +87,7 @@ struct GemmPipelineAgBgCrImplBase
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
load_and_convert_tile<UnaryOpSize>(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
@@ -124,7 +122,7 @@ struct GemmPipelineAgBgCrImplBase
bool_constant<LoadTranspose> = {}) const
{
if constexpr(LoadTranspose)
dst_block_tile = load_tile_transpose(lds_tile_window);
load_tile_transpose(dst_block_tile, lds_tile_window);
else
load_tile(dst_block_tile, lds_tile_window);
}
@@ -241,12 +239,6 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
// with pk_int4_t load transpose the LDS type is always BDataType
using ADataTypeLDS =
std::conditional_t<std::is_same_v<typename Problem::ADataType, pk_int4_t>,
typename Problem::BDataType,
typename Problem::ADataType>;
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
@@ -258,11 +250,16 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_load_tile_distr = []() {
if constexpr(is_a_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<typename ALdsLoadTileDistr::DstrEncode,
ADataTypeLDS>::TransposedDstrEncode{});
typename InputTileDistributionTraits<
typename ALdsLoadTileDistr::DstrEncode,
typename ALdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return ALdsLoadTileDistr{};
}
}();
auto a_lds_gemm_window =
@@ -333,18 +330,18 @@ struct GemmPipelineAgBgCrImplBase
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
{
return make_static_tile_distribution(
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
BLdsDataType>::TransposedDstrEncode{});
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename BLdsTensorView::DataType>::TransposedDstrEncode{});
}
else
{
return BLdsLoadTileDistr{};
}
}();
auto b_lds_gemm_window =

View File

@@ -127,7 +127,6 @@ struct UniversalGemmBasePolicy
using ADataType = OverrideADataType;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
if constexpr(is_a_load_tr<Problem>)
{
@@ -261,6 +260,7 @@ struct UniversalGemmBasePolicy
}
else // A is in RowMajor
{
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
@@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// // Prefetch A0
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
// Prefill A0
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
@@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
do
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
Base::GlobalPrefetch(
@@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
HotLoopScheduler();
}
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
Base::GlobalPrefetch(
@@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
if constexpr(TailNum == TailNumber::Even)
{
{
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
Base::GlobalPrefetch(
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
block_weight_preshuffle(

View File

@@ -34,7 +34,7 @@
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -238,7 +238,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
load_and_convert_tile<UnaryOpSize>(
a_warp_tensor(number<AwarpIter>{}),
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
}

View File

@@ -268,10 +268,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
bool_constant<BLoadTranspose> = {})
{
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
}
// C += A * B

View File

@@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr
// while ADatatype might not be the same as BDataType at the time of problem
// initialization, we can safely use BDataType here because when A would be int4 we will
// ensure A is converted to BDataType prior to loading
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
}
// C += A * B
@@ -395,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_lds_gemm_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_lds_gemm_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_lds_gemm_window);
}
// C += A * B with quantization support

View File

@@ -239,11 +239,9 @@ struct BQuantBlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
}
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)

View File

@@ -202,20 +202,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
const ADramWindow& a_dram_window)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
}
template <typename BDramWindow, typename BBlockTile_>
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
load_and_convert_tile<UnaryOpSize>(b_block_tile, b_dram_window);
}
template <bool HasHotLoop,

View File

@@ -178,10 +178,8 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
ADramWindow& a_dram_window,
const DramTileWindowStep& dram_tile_window_step)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
move_tile_window(a_dram_window, dram_tile_window_step);
}

View File

@@ -174,10 +174,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
ADramWindow& a_dram_window,
const DramTileWindowStep& dram_tile_window_step)
{
using DestDataType = typename ABlockTile_::DataType;
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
move_tile_window(a_dram_window, dram_tile_window_step);
}

View File

@@ -185,10 +185,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
const BDramWindow& b_dram_window)
{
using DestDataType = typename BBlockTile_::DataType;
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
constexpr index_t UnaryOpSize = 8;
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
load_and_convert_tile<UnaryOpSize>(b_block_tile, b_dram_window);
}
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>

View File

@@ -7,7 +7,7 @@
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
@@ -373,8 +373,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -413,8 +413,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
block_sync_lds();
// preload A00,A10 from lds
using ATypeToUse =
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
using ATileType =
decltype(make_static_distributed_tensor<BTypeToUse>(a_warp_tile_distribution));
statically_indexed_array<ATileType, m_preload> a_warp_tensor;
@@ -422,7 +420,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
load_and_convert_tile<UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
@@ -456,8 +454,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -468,7 +466,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
load_and_convert_tile<UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
@@ -481,8 +479,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -511,7 +509,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
load_and_convert_tile<UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
iCounter--;
@@ -529,8 +527,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
aq_block_tile_2 = load_tile(aq_copy_dram_window);
@@ -551,7 +549,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
load_and_convert_tile<UnaryOpSize_>(
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});

View File

@@ -344,8 +344,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -430,8 +430,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -467,8 +467,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -525,8 +525,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * flatNPerWarp, kIter * flatKPerWarp});
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
b_flat_dram_windows(nIter)(kIter));
});
});
bq_block_tile_2 = load_tile(bq_copy_dram_window);

View File

@@ -12,7 +12,7 @@
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,7 +6,7 @@
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -9,7 +9,7 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -6,7 +6,7 @@
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -5,7 +5,7 @@
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -7,7 +7,7 @@
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -14,7 +14,7 @@
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -10,7 +10,7 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -9,7 +9,7 @@
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -5,7 +5,7 @@
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -7,7 +7,7 @@
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -5,7 +5,7 @@
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -7,7 +7,7 @@
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"