diff --git a/CHANGELOG.md b/CHANGELOG.md index 04ba0283ab..c914224bb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added overload of load_tile_transpose that takes reference to output tensor as output parameter +* Use data type from LDS tensor view when determining tile distribution for transpose in the GEMM pipeline * Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index af0f81e832..d1c06d4378 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template -CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, +CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) { - // TODO: Tile windows should works with unknow number of params - // Load element_wise API works only when the input typle is a tuple-tyupe - return tile_window[number<0>{}].load( - tile_window, elementwise, number{}, bool_constant{}); + // TODO: Tile windows should work with unknown number of params + // Load element_wise API works only when the input type is a tuple-type + return tile_windows[number<0>{}].load( + tile_windows, elementwise, number{}, bool_constant{}); } // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. @@ -85,12 +85,12 @@ template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, +CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, bool_constant = {}) { - return tile_window.load(dst_tile, number{}, bool_constant{}); + tile_window.load(dst_tile, number{}, bool_constant{}); } /** @@ -131,7 +131,7 @@ template -CK_TILE_DEVICE auto load_tile_raw(T& tile, +CK_TILE_DEVICE void load_tile_raw(T& tile, const tile_window_linear::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( +CK_TILE_DEVICE void load_tile_transpose_with_offset( + DistributedTensor_& out_tensor, const tile_window_with_static_distribution& __restrict__ tile_window, index_t offset) { + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); + constexpr auto input_distr = TileDistribution_{}; + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; + + // Check that the tile distribution of out_tensor is the expected one for transposed loads. using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose_with_offset(offset); - constexpr auto input_distr = TileDistribution_{}; - constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); + static_assert(std::is_same_v>); + + // Check that the datatype of out_tensor matches that of the bottom tensor view. + static_assert(std::is_same_v); constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); @@ -443,8 +452,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( number{}, trans_tensor.get_thread_buffer().template get_as(number{})); }); - - return out_tensor; } /** @@ -456,6 +463,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -463,16 +471,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param out_tensor A statically distributed tensor containing the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose(DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_with_offset(out_tensor, tile_window, 0); +} + template < typename BottomTensorView_, typename WindowLengths_, @@ -489,7 +518,15 @@ load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) { - return load_tile_transpose_with_offset(tile_window, 0); + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, 0); + + return out_tensor; } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2f2fe12f42..3e28544509 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -190,11 +190,11 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template - CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + CK_TILE_DEVICE auto load(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -202,7 +202,7 @@ struct tile_window_with_static_distribution constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, - tile_window, + tile_windows, elementwise, number{}, bool_constant{}); @@ -210,12 +210,12 @@ struct tile_window_with_static_distribution } template CK_TILE_DEVICE void load(DistributedTensor& dst_tensor, - const TileWindow_& tile_window, + const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -226,14 +226,14 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = typename Base::TileDstr{}; - constexpr auto sizeOfTuple = TileWindow_::size(); + constexpr auto sizeOfTuple = remove_cvref_t::size(); // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; @@ -244,7 +244,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const auto idx_vec_value = generate_tuple( [&](auto jj) { - return tile_window[number{}] + return tile_windows[number{}] .get_bottom_tensor_view() .template get_vectorized_elements( bottom_tensor_thread_coord, diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 00234b20cf..aa0f632c21 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -8,7 +8,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 45fa52e505..9c90db67ed 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index b23e45c233..9cac035c44 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 94243e674f..ad7da5c183 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -3,7 +3,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp similarity index 58% rename from include/ck_tile/ops/common/load_interleaved_pk_type.hpp rename to include/ck_tile/ops/common/load_and_convert_tile.hpp index 3f1a3b8f1c..0748c5fb49 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -5,22 +5,20 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { template -struct InterleavedPKTypeLoader +struct ConverterLoader { template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) { - const element_wise::PassThroughPack8 elementwise_op{}; - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); + const auto tmp = load_tile(src); // NOTE: we rely on types packing neatly here using RawSrcType = typename SrcDataType::type; @@ -29,29 +27,28 @@ struct InterleavedPKTypeLoader using SrcVectorType = ext_vector_t; using DstVectorType = ext_vector_t; static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(dst.get_thread_buffer().template get_as()(i), + tmp.get_thread_buffer().template get_as()[i]); }); } }; -template -CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +template +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(is_packed_type_v) + using SrcDataType = typename WarpWindow::Base::DataType; + using DstDataType = typename WarpTile::DataType; + + if constexpr(is_packed_type_v && !is_packed_type_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type( - dst, src); + ConverterLoader::load_interleaved_pk_type(dst, src); } else if constexpr(LoadTranspose) { - dst = load_tile_transpose(src); + load_tile_transpose(dst, src); } else { diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 5752703ab6..bc72f3b0ba 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 433462b22e..d1b38a8bca 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 2d3a819e80..e08fac48c7 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -22,7 +22,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index eb4aa16d05..0639fa1b36 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -61,7 +61,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9aeabaa8c2..16212c0d13 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR s_acc = gemm_0(q_reg_tensor, k_reg_tensor); dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); } if constexpr(is_epilogue) { @@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); // STAGE 3, P^T@OGrad^T Gemm1 auto pt_reg_tensor = make_static_distributed_tensor( @@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d21928ced..37b4ae41a3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(q_lds_write_window, q_dram_window); async_load_tile(do_lds_write_window, do_dram_window); __builtin_amdgcn_s_waitcnt(0); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - q_reg_tensor = load_tile(q_lds_read_window); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - do_reg_tensor = load_tile(do_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); lse_block_tile = load_tile(lse_dram_window); d_block_tile = load_tile(d_dram_window); @@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); s_waitcnt(); - k_reg_tensor = load_tile(k_lds_read_window); - v_reg_tensor = load_tile(v_lds_read_window); - kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + load_tile_transpose(kt_reg_tensor, kt_lds_read_window); } if constexpr(is_epilogue) { @@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR block_sync_lds(); if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632f..4cca604ff1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_lds_load = [&](auto v_lds_read_idx) { - kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx)); }; decltype(m) m_old; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index e9ed9ac072..c0d5ca291f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -582,7 +582,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // loop over along the [V]alue Sequence length move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + load_tile_transpose(v_tile, v_lds_read_window); }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index e6802e82dc..60f5bd1c4e 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -15,7 +15,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index b1681e07e4..f447ab4452 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -79,7 +79,7 @@ #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 7f34ae24bb..f7f5cd33db 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -218,10 +218,8 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -290,9 +288,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - make_static_tile_distribution(MakeABlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; static constexpr auto BLdsTileDistr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); @@ -349,10 +347,8 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile(a_warp_tile_, - a_lds_gemm_window); - load_int4_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 7cc14ecc39..2a0c09e41f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -79,9 +79,7 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template @@ -89,7 +87,7 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_int4_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -124,7 +122,7 @@ struct GemmPipelineAgBgCrImplBase bool_constant = {}) const { if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); + load_tile_transpose(dst_block_tile, lds_tile_window); else load_tile(dst_block_tile, lds_tile_window); } @@ -241,12 +239,6 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&) const { - // with pk_int4_t load transpose the LDS type is always BDataType - using ADataTypeLDS = - std::conditional_t, - typename Problem::BDataType, - typename Problem::ADataType>; - auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); @@ -258,11 +250,16 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) + { return make_static_tile_distribution( - typename InputTileDistributionTraits::TransposedDstrEncode{}); + typename InputTileDistributionTraits< + typename ALdsLoadTileDistr::DstrEncode, + typename ALdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return ALdsLoadTileDistr{}; + } }(); auto a_lds_gemm_window = @@ -333,18 +330,18 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = std::conditional_t; - auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) + { return make_static_tile_distribution( - typename InputTileDistributionTraits::TransposedDstrEncode{}); - + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename BLdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return BLdsLoadTileDistr{}; + } }(); auto b_lds_gemm_window = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index cb112a11a7..1285cc8cee 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -127,7 +127,6 @@ struct UniversalGemmBasePolicy using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = Derived::template GetSmemPackA(); if constexpr(is_a_load_tr) { @@ -261,6 +260,7 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { + constexpr index_t KPack = Derived::template GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index c9499106de..93999757b0 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // // Prefetch A0 Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( - b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); // Prefill A0 Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); @@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 do { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); Base::GlobalPrefetch( @@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 HotLoopScheduler(); } { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); Base::GlobalPrefetch( @@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 if constexpr(TailNum == TailNumber::Even) { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); block_weight_preshuffle( diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index c2fe66ea5d..0cf4a331d0 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -34,7 +34,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp index 32c53d2f18..a068001482 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -238,7 +238,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - load_int4_tile( + load_and_convert_tile( a_warp_tensor(number{}), a_warp_windows(number{})(number{})); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 2c8b7031f5..24d9f9a1e5 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -268,10 +268,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}) { // If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS - load_int4_tile( - a_warp_tile_, a_block_window); - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 22563da498..8b09530af1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_int4_tile( - a_warp_tile_, a_block_window); - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -395,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile( - a_warp_tile_, a_lds_gemm_window); - load_int4_tile( - b_warp_tile_, b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B with quantization support diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 3af7177365..f5900fcdec 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -239,11 +239,9 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile( - a_warp_tile_, a_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // Load from LDS and scale (then the tile can directly be consumed in the block gemm) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cb36d02aa5..573d76c51f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -202,20 +202,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, const BDramWindow& b_dram_window) { - using DestDataType = typename BBlockTile_::DataType; - using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template ADramWindow& a_dram_window, const DramTileWindowStep& dram_tile_window_step) { - using DestDataType = typename ABlockTile_::DataType; - using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 76d8985fb1..cedc91d564 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -174,10 +174,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index df94eb7273..033a2ab073 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -185,10 +185,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index a7a64518b8..f48e12984c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -7,7 +7,7 @@ #include #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" @@ -373,8 +373,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -413,8 +413,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe block_sync_lds(); // preload A00,A10 from lds - using ATypeToUse = - mixed_prec_compute_type_from_input_t; using ATileType = decltype(make_static_distributed_tensor(a_warp_tile_distribution)); statically_indexed_array a_warp_tensor; @@ -422,7 +420,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -456,8 +454,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -468,7 +466,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); @@ -481,8 +479,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -511,7 +509,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_ping(number{})(number{})); }); iCounter--; @@ -529,8 +527,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); aq_block_tile_2 = load_tile(aq_copy_dram_window); @@ -551,7 +549,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - load_int4_tile( + load_and_convert_tile( a_warp_tensor(loadIter), a_warp_windows_pong(number{})(number{})); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index a49279585e..025ef53dbb 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -344,8 +344,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -430,8 +430,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -467,8 +467,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -525,8 +525,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); bq_block_tile_2 = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 6743e46613..eeb9b1d8a8 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,7 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 1d33ebf39d..07d9989086 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index ebb20aebf4..8f9ab205ac 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 469a98c256..eae0ea14a3 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 88a3d8a137..4d37f4fbc1 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 3e44122afa..faa77d5327 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index 9e31b7bbe2..b5e53283e4 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -14,7 +14,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index ad23a708b7..f271be5006 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -10,7 +10,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 13372f3289..4c2fe9bee4 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 9cf3e08319..c79ba06abf 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 3ee643d729..c7c4171874 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" #include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 090ad0919f..474ba93227 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 7afce1708b..066fbf5fee 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp"