From 1d9320c8f3e39e3fb260b62794ba6a96276ab8cb Mon Sep 17 00:00:00 2001
From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com>
Date: Thu, 16 Oct 2025 20:33:56 +0200
Subject: [PATCH] Wave Tile Transfer supporting global load with transpose (#3027)

* Initial implementation:
  - add new thread group transfer supporting transpose instruction
  - refactor AB transfer to switch between thread and wave tiles methods
* Add some comments and remove explicit wave and lane calculations
* Remove compiler option for performance
* fp16 example: use tuned instance
* Missing cleanup
* Integrate wave transfer in existing gemm and batched gemm instances
* Add fast instances
* extend implementation for 8 bit datatypes
  packed types not supported
* Address review comments
* Optimize pipeline v1 and re-introduce compiler option
* Disable wave tile approach for b scale gemm
* Fix for clang20
* Avoid code duplication of amd_global_load_transpose_to_vgpr function

[ROCm/composable_kernel commit: 440358c16851de74575798c539feca1b0be0799f]
---
 example/01_gemm/gemm_wmma_fp16_v3.cpp         |  17 +-
 .../blockwise_gemm_pipeline_wmmaops_v1.hpp    | 133 ++-
 ...ead_group_tensor_slice_transfer_global.hpp | 405 +++++++++
 .../gridwise_ab_transfer_thread_tiles.hpp     | 402 +++++++++
 .../grid/gridwise_ab_transfer_wave_tiles.hpp  | 343 +++++++
 .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp   |  12 +-
 ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp |   9 +-
 .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 842 ++++--------------
 include/ck/utility/amd_transpose_load.hpp     |  37 +
 include/ck/utility/dynamic_buffer.hpp         |  13 +-
 include/ck/utility/synchronization.hpp        |  16 +-
 ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp |   1 +
 ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp |   1 +
 ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp |   1 +
 ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp |   1 +
 15 files changed, 1513 insertions(+), 720 deletions(-)
 create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp
 create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp
 create mode 100644 include/ck/utility/amd_transpose_load.hpp

diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp
index 7225dba721..7699364a7a 100644
--- a/example/01_gemm/gemm_wmma_fp16_v3.cpp
+++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp
@@ -26,17 +26,18 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
         ALayout, BLayout, CLayout,
         ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
         PassThrough, PassThrough, PassThrough, GemmDefault,
-        128,
-        128, 64,
-        64, 8, 8,
+        256,
+        128, 256, 64,
+        8, 8,
         16, 16,
-        4, 2,
-        S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
+        2, 8,
+        S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
         1, 1, 8, 1,
-        S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
+        S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
         1, 1, 8, 1,
-        1, 1, S<1, 32, 1, 4>, 8,
-        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
+        1, 1,
+        S<1, 64, 1, 4>, 8,
+        ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
index 76d748eb27..87ccc7c5e0 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -116,6 +116,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1; using Base::I0; + using Base::I1; + using Base::WaveSize; + using typename Base::HotLoopInstList; using Base::A_K1; using Base::A_KRow; @@ -213,38 +216,42 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, m0, I0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0, I0), a_thread_buf); - }); - if constexpr(ck::is_same::value == true) - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); - }); - } - else - { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.b_scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); - }); - } - static_for<0, MRepeat, 1>{}([&](auto m0) { + if constexpr(m0 == I0) + { + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple( + Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple( + Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + }); + } + } + static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -252,12 +259,12 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto ik) { a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + Number{}, I0, I0, I0, I0, Number{}))>{}]; }); static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; + Number{}, n0, I0, I0, I0, Number{}))>{}]; }); using wmma_input_type_a = @@ -296,6 +303,32 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + static_for<0, KRepeat, 1>{}([&](auto) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + if constexpr(m0 == I0) + { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + }); + }); + static_for<0, num_ds_write_inst, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + }); + i += 1; } while(i < (num_loop - 1)); } @@ -309,10 +342,38 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, I1, I1, Number{})); + + // B[NRepeat, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{})); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy 
a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; using Base::c_thread_desc_; }; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp new file mode 100644 index 0000000000..a74358d4dc --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/functional2.hpp" +#include "ck/utility/dtype_vector.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +template +struct ThreadGroupTransferGlobal +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + __device__ ThreadGroupTransferGlobal(const SrcDesc& src_desc, + const DstDesc& dst_desc, + const Index& src_block_slice_origin, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_block_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_block_slice_origin)), + element_op_(element_op) + { + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const GridBufferType& grid_buf) + { + constexpr auto src_access_lengths = NumberOfIterations{}; + constexpr auto src_dim_access_order = IterationOrder{}; + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + constexpr auto ordered_fwd_step = StepsPerIteration{}; + + // make forward steps + // forward step for each iteration just add 1 + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? ordered_fwd_step[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + // backward step at the end of the dimension iteration subtract IterationLength - 1 + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? 
(-src_access_lengths[i] + 1) * ordered_fwd_step[i] + : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + // Take condition for bwd and negate + // condition for bwd: dimension index is the last of iteration and + // all dimension indices of higher dimensions (inner loops) + // are the last of their iteration + static_for<0, nDim, 1>{}([&](auto i) { + bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1; + static_for{}([&](auto j) { + tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + forward_sweep_(i) = !tmp; + }); + return forward_sweep_; + }(); + + // check for each dimension, if it needs to be moved (either fwd or bwd) + constexpr auto move_on_dim = [&]() constexpr { + StaticallyIndexedArray move_on_dim_; + + // forward condition + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + // backward condition + static_for<0, nDim, 1>{}([&](auto i) { + bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1 && + ordered_src_access_idx[i] > 0; + static_for{}([&](auto j) { + tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + move_on_dim_(i) |= tmp; + }); + + return move_on_dim_; + }(); + + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. 
Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + // check if src element is valid + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // Vector length of elementwise operation + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, VectorSize); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, VectorSize); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, VectorSize); + } + else + { + return 1; + } + }; + + // This is 1 for pass through because internally it's doing type conversion + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + using elem_op_vec_t = typename vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + using vector_t = typename vector_type_maker::type::type; + + dst_vector_type op_r_v; + + // Load data from memory in src_vector first + src_vector_container src_vector = + src_vector_container{grid_buf.template Get( + src_coord_.GetOffset(), true)}; + + // apply the src elementwise op and convert to DstData under the hood if needed + static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { + element_op_(op_r_v.template AsType()(idx), + src_vector.template AsType()[idx]); + }); + + // store result in dvgpr_ (static array holding loaded data). + // At this point data is already converted to DstData type and + // the elementwise operation has been applied + dvgpr_.template SetAsType( + vgpr_data_idx_seq, + is_src_valid ? op_r_v.template AsType()[I0] : vector_t(0)); + + // For each dimension move fwd, bwd or don't move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, BlockBufferType& dst_buf) + { + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + constexpr auto src_access_lengths = NumberOfIterations{}; + constexpr auto src_dim_access_order = IterationOrder{}; + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + constexpr auto ordered_fwd_step = StepsPerIteration{}; + + // make forward steps + // forward step for each iteration just add 1 + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
ordered_fwd_step[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + // backward step at the end of the dimension iteration subtract IterationLength - 1 + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? (-src_access_lengths[i] + 1) * ordered_fwd_step[i] + : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + // Take condition for bwd and negate + // condition for bwd: dimension index is the last of iteration and + // all dimension indices of higher dimensions (inner loops) + // are the last of their iteration + static_for<0, nDim, 1>{}([&](auto i) { + bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1; + static_for{}([&](auto j) { + tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + forward_sweep_(i) = !tmp; + }); + return forward_sweep_; + }(); + + // check for each dimension, if it needs to be moved (either fwd or bwd) + constexpr auto move_on_dim = [&]() constexpr { + StaticallyIndexedArray move_on_dim_; + + // forward condition + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + // backward condition + static_for<0, nDim, 1>{}([&](auto i) { + bool tmp = ordered_src_access_idx[i] == ordered_src_access_lengths[i] - 1 && + ordered_src_access_idx[i] > 0; + static_for{}([&](auto j) { + tmp &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + move_on_dim_(i) |= tmp; + }); + + return move_on_dim_; + }(); + + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. 
Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + // store element from vgpr to dst buffer + dst_buf.template Set( + dst_coord_.GetOffset(), + true, + dvgpr_.template GetAsType(vgpr_data_idx_seq)); + + // For each dimension move fwd, bwd or don't move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + const auto adjusted_step = make_tensor_coordinate_step(src_desc, step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + // descriptor of vgpr data + __device__ static constexpr auto GetThreadScratchDataDescriptor() + { + constexpr auto access_lengths_as_tuple = container_push_back( + sequence_to_tuple_of_number(NumberOfIterations{}), Number{}); + + return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); + } + + static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){}; + using ThreadScratchData = StaticTensorTupleOfVectorBuffer; + + ThreadScratchData dvgpr_; + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp new file mode 100644 index 0000000000..465952e285 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -0,0 +1,402 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
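RunRead and RunWrite in ThreadGroupTransferGlobal above walk the access space like an odometer: after every load or store, the fastest dimension that still has iterations left takes one forward step, and each faster dimension that has just finished takes a backward step of (length - 1) iterations, which resets it to zero. That is exactly what the forward_sweep and move_on_dim lambdas encode. The stand-alone host sketch below replays those rules for an assumed 2-dimensional 3 x 4 iteration space; the sizes and the main() driver are illustrative only and are not part of the patch:

#include <cstdio>

int main()
{
    constexpr int lengths[2] = {3, 4}; // dim 0 iterates slowly, dim 1 iterates fastest
    int idx[2]               = {0, 0};

    for(int access = 0; access < lengths[0] * lengths[1]; ++access)
    {
        std::printf("access (%d, %d)\n", idx[0], idx[1]);

        // Mirror of move_on_dim / forward_sweep: the fastest dimension that is
        // not exhausted takes one forward step; every faster dimension that
        // just wrapped takes a backward step of (length - 1), i.e. resets to 0.
        for(int d = 1; d >= 0; --d)
        {
            if(idx[d] < lengths[d] - 1)
            {
                idx[d] += 1;
                break;
            }
            idx[d] = 0;
        }
    }
    return 0;
}

Pre-building one forward and one backward tensor coordinate step per dimension, as the generate_tuple calls in RunRead/RunWrite do, is what keeps the per-access update down to plain move_tensor_coordinate calls.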
+ +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" + +namespace ck { + +template +struct ABTransferThreadTiles +{ + static constexpr auto ABK0Number = Number{}; + static constexpr auto ABK1Number = Number{}; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t ABPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + using ThisThreadBlock = ThisThreadBlock; + + template + __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc, + index_t MN, + index_t MNPad, + index_t K, + index_t KPad, + index_t StrideAB, + index_t ABK0) + { + + if constexpr(PadMN && PadK) + { + // pad both MN and K + const auto ab_grid_desc_n_k = + transform_tensor_descriptor(ab_grid_desc, + make_tuple(make_right_pad_transform(MN, MNPad - MN), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + ab_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)), + make_pass_through_transform(MNPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return ab_grid_desc_bk0_n_bk1; + } + else if constexpr(PadMN && !PadK) + { + // pad MN, but not K + const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + ab_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)), + make_right_pad_transform(MN, MNPad - MN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return ab_grid_desc_bk0_n_bk1; + } + else if constexpr(!PadMN && PadK) + { + // pad K, but not MN + const auto ab_grid_desc_n_k = transform_tensor_descriptor( + ab_grid_desc, + make_tuple(make_pass_through_transform(MN), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + ab_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)), + make_pass_through_transform(MN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return ab_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteAB) + { + // not pad MN or K + const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + ab_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)), + make_pass_through_transform(MN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return ab_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1] + constexpr index_t ABK01 = KPerBlock / ABK1Value; + const index_t ABK0_ = StrideAB / ABK1Value; + const index_t ABK00 = ABK0_ / ABK01; + + const auto ab_grid_desc_abk00_mn_abk01_abk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value)); + + const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor( + ab_grid_desc_abk00_mn_abk01_abk1_permute, + make_tuple(make_merge_transform(make_tuple(ABK00, ABK01)), + 
make_pass_through_transform(make_tuple(MN)), + make_pass_through_transform(ABK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return ab_grid_desc_abk0_mn_abk1_permute; + } + } + } + + __device__ static constexpr auto GetBlockDescriptor() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(UseBlockPaddingAB) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(ABK0Number, Number{}, ABK1Number), + make_tuple(Number{} * ABK1Number, ABK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize; + constexpr auto MNLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(ABK0Number * Number{}, + Number{}, + ABK1Number), + make_tuple(ABK1Number, Number{}, I1)); + + constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor( + ab_lds_block_desc, + make_tuple( + make_xor_with_modulo_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(ABK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor( + ab_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(ABK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(ABK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor( + ab_lds_block_desc_abk0_mnldslayer_mn_abk1, + make_tuple(make_pass_through_transform(ABK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(ABK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return ab_lds_block_desc_abk0_mn_abk1; + } + else + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto MN0 = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1); + constexpr auto MN1 = MNPerBlock / MN0; + + constexpr auto KThreadWrite = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0); + constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MNPerWmma; + constexpr auto K0PerThreadRead = ABK0Number / KThreadRead; + + constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128) + ? 1 + : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128) + ? 1 + : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0 + ? 
MN0 + : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))); + + constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + ABK1Number)); + + constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor( + ab_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(ABK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor( + ab_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(ABK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor( + ab_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(ABK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return ab_lds_block_desc_abk0_mn_abk1; + } + } + + template + __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, + BlockDescriptor& block_descriptor, + ABElementwiseOperation& ab_element_op, + const index_t block_mn_id) + { + constexpr index_t NumABTensor = ABsDataType::Size(); + const index_t mn_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock); + // workaround because v7r2 is not as general as v4r1 + if constexpr(NumABTensor > 1) + { + const auto idx_as_block_begin = generate_tuple( + [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); }, + Number{}); + + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + ABsDataType, + Tuple, + GridDescriptor, + decltype(tie(block_descriptor)), + ABElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1, + ABBlockTransferThreadClusterArrangeOrder, + ABBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABBlockTransferSrcVectorDim, + 2, + ABBlockTransferSrcScalarPerVector, + ABBlockTransferDstScalarPerVector_ABK1, + uniform_sequence_gen_t, + Sequence, + GlobalBufferNum>{grid_descriptor, + idx_as_block_begin, + tie(block_descriptor), + make_tuple(make_multi_index(0, 0, 0)), + ab_element_op}; + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + ABElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1, + ABBlockTransferThreadClusterArrangeOrder, + remove_cvref_t>, + remove_cvref_t>, + 
decltype(grid_descriptor[I0]), + decltype(block_descriptor), + ABBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABBlockTransferSrcVectorDim, + 2, + ABBlockTransferSrcScalarPerVector, + ABBlockTransferDstScalarPerVector_ABK1, + 1, + 1, + ABThreadTransferSrcResetCoordinateAfterRun, + true, + GlobalBufferNum>(grid_descriptor[I0], + make_multi_index(0, mn_block_data_idx_on_grid, 0), + ab_element_op, + block_descriptor, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor() + { + // This is a block descriptor used to read LDS memory into register + // It's defined in a way consistent with the existing implementation to + // avoid changes in the pipelines + using BlockDesc = decltype(GetBlockDescriptor()); + // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1 + constexpr auto ABK0 = BlockDesc{}.GetLength(I0); + constexpr auto ABK1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __device__ static constexpr auto GetBlockStep() + { + // Grid descriptor step (MoveSrcSliceWindow) + return make_multi_index(KPerBlock / ABK1Number, 0, 0); + } + + template + __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc) + { + // K dimension size. This should always be called with the A matrix grid descriptor + // because it doesn't work for B matrix when packed int4 is used + return grid_desc.GetLength(I0) * grid_desc.GetLength(I2); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp new file mode 100644 index 0000000000..68476ef3bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
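MakeWmmaTileDescriptor above rebuilds the ABK0_MN_ABK1 block descriptor as ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1 by unmerging the flat MN coordinate into (repeat, wave, lane-within-WMMA-tile). A minimal sketch of that index split, assuming the tile sizes of the tuned fp16 instance at the top of the patch (MPerBlock = 128, MPerWmma = 16, MRepeat = 2, hence MWaves = 4); the decompose helper is illustrative and not a library function:

#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128;
    constexpr int MPerWmma  = 16;
    constexpr int MRepeat   = 2;
    constexpr int MWaves    = MPerBlock / (MRepeat * MPerWmma); // 4 waves along M

    // make_unmerge_transform splits a flat index into nested coordinates,
    // fastest-varying component last: mn -> (repeat, wave, lane_in_wmma).
    auto decompose = [&](int mn) {
        const int lane   = mn % MPerWmma;
        const int wave   = (mn / MPerWmma) % MWaves;
        const int repeat = mn / (MPerWmma * MWaves);
        std::printf("mn = %3d -> repeat %d, wave %d, lane %2d\n", mn, repeat, wave, lane);
    };

    decompose(0);   // repeat 0, wave 0, lane  0
    decompose(17);  // repeat 0, wave 1, lane  1
    decompose(127); // repeat 1, wave 3, lane 15
    return 0;
}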
+ +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +template +struct ABTransferWaveTiles +{ + static_assert(!(is_same_v, pk_i4_t>), + "wave tile transfer method does not support pk_i4_t"); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t MNKRow = 2; + + using ThisThreadBlock = ThisThreadBlock; + + // Tiles distribution for global memory loading + // Notes: support for not power of 2 needs to be reviewed later on + // The tiles are distributed along the non-contiguous matrix dimension + // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64 + // MRepeat = 1, KRepeat = 4 + // ------------- + // |W0| | | | + // ------------- + // |W1| | | | + // ------------- + // |W2| | | | + // ------------- + // |W3| | | | + // ------------- + // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64 + // MRepeat = 4, KRepeat = 1 + // ------------- + // |W0|W1|W2|W3| + // ------------- + // | | | | | + // ------------- + // | | | | | + // ------------- + // | | | | | + // ------------- + static constexpr index_t NumberOfWaves = BlockSize / WaveSize; + static constexpr index_t MNMajorWaves_ = + MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0 + ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves) + : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1); + static constexpr index_t KMajorWaves_ = + KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0 + ? std::min(KPerBlock / KPack, NumberOfWaves) + : (KPerBlock / KPack % 2 == 0 ? 2 : 1); + + static constexpr bool ABDoTranspose = !is_same_v; + + static constexpr index_t MNWaves_ = + ABDoTranspose ? NumberOfWaves / KMajorWaves_ : MNMajorWaves_; + static constexpr index_t KWaves_ = ABDoTranspose ? 
KMajorWaves_ : NumberOfWaves / MNMajorWaves_; + static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); + static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + + template + __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t, + index_t sizeK, + index_t, + index_t, + index_t) + { + // Notes: padding is currently not supported + static_assert(!PadMN && !PadK, "padding is currently not supported"); + + // Divide the base descriptor MN_K into tiles + const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( + base_desc, + make_tuple( + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(sizeMN, Number{}), Number{})), + make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number{}), + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + // The distinction is needed to get the same global indices for both layouts + // Divide each tile in 2 16x8 subtile + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + // MNKRow = 0-1 + // LaneLocal = 0-15 + // VectorSize must be 8 + if constexpr(!ABDoTranspose) + { + const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 = + transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles, + make_tuple(make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform( + math::integer_divide_ceil(sizeK, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + // Freeze VectorSize to first element of the loading chunk (for convenience) + // Swap MNPerWmma and MNKRow for consistency with transpose descriptor + return transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{})); + } + else + { + const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 = + transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles, + make_tuple(make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform( + math::integer_divide_ceil(sizeK, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + // Freeze VectorSize to first element of the loading chunk (for convenience) + return transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(sizeMN, Number{})), + make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + make_pass_through_transform(Number{}), + make_freeze_transform(I0), + make_pass_through_transform(Number{})), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + } + } + + __device__ static constexpr auto GetBlockDescriptor() + { + // LDS memory layouts: + // lanes within tiles stored contiguously in chunks of 8 elements + // tiles are then stored first in K dimension + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + const auto a_grid_desc_mraw_kraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + I1)); + }(); + + // Freeze VectorSize to first element of the chunk (for convenience) + return transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<>{})); + } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MNWaves_, KWaves_, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetBlockLaneIdx() + { + const index_t lane_id = __lane_id(); + + constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma; + + constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); + } + + template + __device__ static auto GetGridLaneIdx() + { + const index_t lane_id = __lane_id(); + + constexpr index_t SubTilesRow = MNKRow; + constexpr index_t SubTilesCol = 4 / sizeof(ABDataType); + constexpr index_t LanesPerSubTile = + ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol; + constexpr auto dims_tuple = ABDoTranspose + ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile) + : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile); + + constexpr auto laneid_to_grid_lane_idx_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(dims_tuple)), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto indices = + laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id)); + + if constexpr(!ABDoTranspose) + { + return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]); + } + else + { + return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]); + } + } + + template + __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, + BlockDescriptor& block_descriptor, + ABElementwiseOperation& ab_element_op, + const index_t block_mn_id) + { + // Note: GlobalBufferNum is currently not used but it will be needed + // once we add other pipelines. 
It is currently needed only for + // consistency with the thread tiles approach + static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); + constexpr index_t NumABTensor = ABsDataType::Size(); + static_assert(NumABTensor == 1, "multiAB currently not supported"); + + using ABDataType = remove_cvref_t>; + + const auto wave_idx = GetWaveIdx(); + index_t wave_idK = wave_idx[I1]; + index_t wave_idMN = wave_idx[I0]; + + const auto grid_lane_id = GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + + const auto block_lane_id = GetBlockLaneIdx(); + index_t lane_group_block = block_lane_id[I0]; + index_t lane_local_id_block = block_lane_id[I1]; + + return ThreadGroupTransferGlobal, + Sequence, + Sequence, + ABK1Value, + ABDoTranspose>( + grid_descriptor[I0], + block_descriptor, + make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN, + wave_idK, + lane_group_grid, + lane_local_id_grid), + make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block), + ab_element_op); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor() + { + // This is a block descriptor used to read LDS memory into register + // It's defined in a way consistent with the existing implementation to + // avoid changes in the pipelines + return make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + I1)); + } + + __device__ static constexpr auto GetBlockStep() + { + // Grid descriptor step (MoveSrcSliceWindow) + return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0); + } + + template + __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc) + { + return grid_desc.GetLength(I1) * KPack; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index d226510cf0..25653dd859 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -175,7 +175,8 @@ template + bool PermuteB, + bool ForceThreadTileTransfer = false> struct GridwiseGemm_wmma_cshuffle_v3 : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -227,7 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB> + PermuteB, + ForceThreadTileTransfer> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -279,7 +281,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + ForceThreadTileTransfer>; using Base::I0; using Base::I1; @@ -318,9 +321,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 using ThisThreadBlock = ThisThreadBlock; - using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; - using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; - using Base::NumATensor; using Base::NumBTensor; using Base::NumDTensor; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 36724d5745..1b8a8ef09e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -122,7 +122,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale 
ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB> + PermuteB, + true> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -174,7 +175,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + true>; using Base::I0; using Base::I1; @@ -213,9 +215,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using ThisThreadBlock = ThisThreadBlock; - using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; - using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; - using Base::NumATensor; using Base::NumBTensor; using Base::NumDTensor; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index dac0c9b3b0..523cb8efd1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -14,10 +14,13 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -107,7 +110,8 @@ template + bool PermuteB, + bool ForceThreadTileTransfer = false> // only needed for convolution (limitation) struct GridwiseGemm_wmma_cshuffle_v3_base { @@ -162,6 +166,101 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return 1; }(); + // Limitations of the current implementation: + // - no multiAB + // - GemmSpecialization Default + // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation) + // AK1Value == 8 is not really a limitation but a requirement for the method so + // it will stay +#ifdef __gfx12__ + static constexpr bool IsAWaveTransferApplicable = + !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + GemmSpec == tensor_operation::device::GemmSpecialization::Default && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8; + + static constexpr bool IsBWaveTransferApplicable = + !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + GemmSpec == tensor_operation::device::GemmSpecialization::Default && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; +#else + static constexpr bool IsAWaveTransferApplicable = false; + static constexpr bool IsBWaveTransferApplicable = false; +#endif + + static constexpr index_t WaveSize = + WmmaSelector::selected_wmma + .wave_size; + static constexpr bool UseBlockPaddingA = + ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; + using ATransfer = typename std::conditional< + IsAWaveTransferApplicable, + ABTransferWaveTiles, + ABTransferThreadTiles>::type; + + static constexpr bool UseBlockPaddingB = + BBlockLdsExtraN || BlkGemmPipelineVer == 
BlockGemmPipelineVersion::v4; + + using BTransfer = typename std::conditional< + IsBWaveTransferApplicable, + ABTransferWaveTiles, + ABTransferThreadTiles>::type; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != tensor_operation::device::GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + static_assert(!PermuteA, "PermuteA is not supported"); + // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -222,27 +321,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return math::integer_divide_ceil(N, NPerBlock); } - template - __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) - { - // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 - constexpr auto K0 = BlockDesc{}.GetLength(I0); - constexpr auto K1 = BlockDesc{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto KRow = I2; -#else - constexpr auto KRow = I1; -#endif - return transform_tensor_descriptor( - BlockDesc{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - } - static constexpr auto MakeAsGridPointer() { return generate_tuple( @@ -268,87 +346,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using AsGridPointer = decltype(MakeAsGridPointer()); using BsGridPointer = decltype(MakeBsGridPointer()); - __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( - index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + __host__ __device__ static auto MakeAGridDescriptor_M_K(index_t M, index_t K, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) + if constexpr(is_same_v) { - // pad both M and K - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) + else if constexpr(is_same_v) { - // pad M, but not K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_right_pad_transform(M, MPad - M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; + return 
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) + } + + __host__ __device__ static auto MakeBGridDescriptor_N_K(index_t N, index_t K, index_t StrideB) + { + if constexpr(is_same::value) { - // pad K, but not M - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); } - else + else if constexpr(is_same::value) { - static_assert(!PermuteA, "PermuteA is not supported"); - - // not pad M or K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); } } @@ -360,123 +378,25 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideAs, const index_t AK0) { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding; + constexpr bool padK = GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding; return generate_tuple( [&](auto i) { - return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0); + const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]); + + return ATransfer::template MakeGridDescriptor( + base_desc, M, MPad, K, KPad, StrideAs[i], AK0); }, Number{}); } - __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( - index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - static_assert(!(is_same_v, pk_i4_t> && - GemmSpec != GemmSpecialization::Default), - "pk_i4_t does not support padding"); - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - 
b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - if constexpr(!PermuteB) - { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // Pre-shuffled Weight - // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] - constexpr index_t BK01 = KPerBlock / BK1Value; - const index_t BK0_ = StrideB / BK1Value; - const index_t BK00 = BK0_ / BK01; - - const auto b_grid_desc_bk00_n_bk01_bk1_permute = - make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); - - const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( - b_grid_desc_bk00_n_bk01_bk1_permute, - make_tuple(make_merge_transform(make_tuple(BK00, BK01)), - make_pass_through_transform(make_tuple(N)), - make_pass_through_transform(BK1Value)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_grid_desc_bk0_n_bk1_permute; - } - } - } - __host__ __device__ static auto MakeBsGridDescriptor_BK0_N_BK1(const index_t K, const index_t KPad, @@ -485,27 +405,36 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideBs, const index_t BK0) { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding; + constexpr bool padK = GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding; return generate_tuple( [&](auto i) { - return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0); + const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]); + return BTransfer::template MakeGridDescriptor( + 
base_desc, N, NPad, K, KPad, StrideBs[i], BK0); }, Number{}); } - template - __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() { constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return ATransfer::template MakeWmmaTileDescriptor(); } - template - __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor() { constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return BTransfer::template MakeWmmaTileDescriptor(); } template @@ -610,278 +539,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base Number{}); } - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. 
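The LDS block descriptors being removed here (now provided behind ATransfer/BTransfer::GetBlockDescriptor()) use make_xor_with_modulo_transform to permute the column index per row, so a wave reading one column of the staged tile does not hammer a single LDS bank. A minimal standalone sketch of that idea, assuming a hypothetical swizzled_col helper; the exact index formula used by make_xor_with_modulo_transform may differ:

#include <cstdio>

// Illustration only, not composable_kernel code.
// Without swizzling, column c of every row sits at offset c, so reading one
// column across rows hits the same LDS bank repeatedly. XOR-ing the column
// with the row index permutes columns per row and spreads those reads out.
constexpr int swizzled_col(int row, int col, int num_cols)
{
    return (col ^ row) % num_cols;
}

int main()
{
    constexpr int rows = 8, cols = 8;
    for(int r = 0; r < rows; ++r)
    {
        for(int c = 0; c < cols; ++c)
            std::printf("%d ", swizzled_col(r, c, cols));
        std::printf("\n");
    }
    return 0;
}

With a power-of-two column count the XOR is a per-row permutation, so every bank is still written exactly once per row, while reads of a logical column now stride across distinct banks.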
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerWmma; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0 - ? M0 - : 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerWmma; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128) - ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - __host__ __device__ static constexpr auto // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() @@ -899,28 +556,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - LDSTypeA, - LDSTypeB, - ComputeTypeA, - ComputeTypeB, - AccDataType, - decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>())>; + using BlockwiseGemmPipe = + remove_cvref_t())>; template __device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1168,8 +824,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor(); + constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor(); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1257,161 +913,32 @@ struct 
GridwiseGemm_wmma_cshuffle_v3_base auto e_grid_buf = make_dynamic_buffer( p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = ATransfer::GetBlockDescriptor(); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = BTransfer::GetBlockDescriptor(); // A matrix blockwise copy - // workaround because v7r2 is not as general as v4r1 - auto get_a_blockwise_transfer = [&]() { - if constexpr(NumATensor > 1) - { - const auto idx_as_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, - Number{}); - - return ThreadGroupTensorSliceTransfer_v7r2< - ThisThreadBlock, - AsDataType, - Tuple, - AGridDesc_AK0_M_K1, - decltype(tie(a_block_desc_ak0_m_ak1)), - AElementwiseOperation, - Sequence(InMemoryDataOperationEnum::Set)>, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - uniform_sequence_gen_t, - Sequence, - BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, - idx_as_block_begin, - tie(a_block_desc_ak0_m_ak1), - make_tuple(make_multi_index(0, 0, 0)), - a_element_op}; - } - else - { - return ThreadGroupTensorSliceTransfer_v4r1< - ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - remove_cvref_t>, - remove_cvref_t>, - decltype(as_grid_desc_ak0_m_ak1[I0]), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - as_grid_desc_ak0_m_ak1[I0], - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - } - }; - - auto a_blockwise_copy = get_a_blockwise_transfer(); + auto a_blockwise_copy = + ATransfer::template GetBlockTransfer( + as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id); // B matrix blockwise copy - // workaround because v7r2 is not as general as v4r1 - auto get_b_blockwise_transfer = [&]() { - if constexpr(NumBTensor > 1) - { - const auto idx_bs_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, - Number{}); - - return ThreadGroupTensorSliceTransfer_v7r2< - ThisThreadBlock, - BsDataType, - Tuple, - BGridDesc_BK0_N_K1, - decltype(tie(b_block_desc_bk0_n_bk1)), - BElementwiseOperation, - Sequence(InMemoryDataOperationEnum::Set)>, - Sequence, - 
BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - uniform_sequence_gen_t, - Sequence, - BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, - idx_bs_block_begin, - tie(b_block_desc_bk0_n_bk1), - make_tuple(make_multi_index(0, 0, 0)), - b_element_op}; - } - else - { - return ThreadGroupTensorSliceTransfer_v4r1< - ThisThreadBlock, - BElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - remove_cvref_t>, - remove_cvref_t>, - decltype(bs_grid_desc_bk0_n_bk1[I0]), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - bs_grid_desc_bk0_n_bk1[I0], - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - } - }; - - auto b_blockwise_copy = get_b_blockwise_transfer(); + auto b_blockwise_copy = + BTransfer::template GetBlockTransfer( + bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1427,8 +954,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + constexpr auto a_block_slice_copy_step = ATransfer::GetBlockStep(); + constexpr auto b_block_slice_copy_step = BTransfer::GetBlockStep(); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -1436,8 +963,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / - KPerBlock); + ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock); blockwise_gemm_pipeline.template Run( get_first_element_workaround(as_grid_desc_ak0_m_ak1), diff --git a/include/ck/utility/amd_transpose_load.hpp b/include/ck/utility/amd_transpose_load.hpp new file mode 100644 index 0000000000..6ef17b18da --- /dev/null +++ b/include/ck/utility/amd_transpose_load.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
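+
+// amd_global_load_transpose_to_vgpr wraps the gfx12 "global load with
+// transpose" builtins used below (__builtin_amdgcn_global_load_tr_b128_v8f16
+// and __builtin_amdgcn_global_load_tr_b64_v2i32). The lanes of a wave load a
+// small tile whose elements come back transposed across those lanes, so the
+// wave-tile transfer path can pull K-major operand data straight from global
+// memory into the register layout the WMMA instructions expect, without
+// staging through LDS. Each lane receives eight 16-bit elements (b128) or
+// eight 8-bit elements (b64); the exact tile shape and lane-to-element
+// mapping are defined by the gfx12 ISA.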
+
+#pragma once
+#include "data_type.hpp"
+
+namespace ck {
+
+#if defined(__gfx12__)
+template <typename T>
+__device__ auto amd_global_load_transpose_to_vgpr(const T* in_ptr)
+{
+    using vector_t = typename vector_type<T, 8>::type;
+    if constexpr(sizeof(T) == 2)
+    {
+        typedef __attribute__((__vector_size__(8 * sizeof(__fp16)))) __fp16 llvm_fp16x8_t;
+        __attribute__((address_space(1))) llvm_fp16x8_t* glb_ptr =
+            reinterpret_cast<__attribute__((address_space(1))) llvm_fp16x8_t*>(
+                reinterpret_cast<size_t>(in_ptr));
+        return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b128_v8f16(glb_ptr));
+    }
+    else if constexpr(sizeof(T) == 1)
+    {
+        typedef __attribute__((__vector_size__(2 * sizeof(int)))) int llvm_intx2_t;
+        __attribute__((address_space(1))) llvm_intx2_t* glb_ptr =
+            reinterpret_cast<__attribute__((address_space(1))) llvm_intx2_t*>(
+                reinterpret_cast<size_t>(in_ptr));
+        return bit_cast<vector_t>(__builtin_amdgcn_global_load_tr_b64_v2i32(glb_ptr));
+    }
+    else
+    {
+        static_assert(false, "not implemented");
+    }
+}
+#endif
+
+} // namespace ck
diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp
index a1f3ee2d78..66166e11e3 100644
--- a/include/ck/utility/dynamic_buffer.hpp
+++ b/include/ck/utility/dynamic_buffer.hpp
@@ -12,6 +12,7 @@
 #else
 #include "amd_buffer_addressing.hpp"
 #endif
+#include "amd_transpose_load.hpp"
 #include "generic_memory_space_atomic.hpp"
 
 namespace ck {
@@ -69,6 +70,7 @@ struct DynamicBuffer
     __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }
 
     template <typename X,
+              bool DoTranspose = false,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                          typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                      !is_native_type<X>(),
@@ -89,7 +91,8 @@ struct DynamicBuffer
         bool constexpr use_amd_buffer_addressing = false;
 #endif
 
-        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
+        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing &&
+                     !DoTranspose)
         {
             constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
@@ -112,6 +115,14 @@ struct DynamicBuffer
                                                                          invalid_element_value_);
             }
         }
+        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && DoTranspose)
+        {
+#ifdef __gfx12__
+            return amd_global_load_transpose_to_vgpr(p_data_ + i);
+#else
+            static_assert(!DoTranspose, "load-with-transpose only supported on gfx12+");
+#endif
+        }
         else
         {
             if(is_valid_element)
diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp
index 7652e73809..672fc8c31b 100644
--- a/include/ck/utility/synchronization.hpp
+++ b/include/ck/utility/synchronization.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once @@ -7,15 +7,19 @@ namespace ck { +#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#ifdef __gfx12__ +__device__ void llvm_amdgcn_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.dscnt"); +#endif +#endif + __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #ifdef __gfx12__ - asm volatile("\ - s_wait_dscnt 0x0 \n \ - s_barrier_signal -1 \n \ - s_barrier_wait -1 \ - " ::); + llvm_amdgcn_s_wait_dscnt(0); + asm volatile("s_barrier_signal -1\n\t" + "s_barrier_wait -1"); #else // asm volatile("\ // s_waitcnt lgkmcnt(0) \n \ diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index a439cf27f5..71b5c5e7cf 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -44,6 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, 
BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 55e0362018..f4489dc45f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -42,6 +42,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index e51de0556c..423f86365c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -49,6 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index 722a0bae55..2eb28958e6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -51,6 +51,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Interwave, 
BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
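
Each of the four layout variants above gains the same wave-tile-friendly instance: BlockSize 256, a 128x256x64 tile, 2x8 WMMA repeats, S<8, 32, 1> thread clusters, and the Intrawave v1 pipeline. As a rough sanity check of that tile's cost, a sketch of its nominal A/B LDS staging footprint for fp16; the trailing LdsExtra padding flags and the C-shuffle workspace are ignored here, so the kernel's actual allocation is somewhat larger:

#include <cstdio>

// Nominal LDS staging for the new 128x256x64 fp16 tile (BlockSize-256
// instances added above). Illustration only; padding and the C-shuffle
// buffer are not counted.
int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 256, KPerBlock = 64;
    constexpr int bytes_per_elem = 2; // fp16

    constexpr int a_tile = MPerBlock * KPerBlock * bytes_per_elem; // 16 KiB
    constexpr int b_tile = NPerBlock * KPerBlock * bytes_per_elem; // 32 KiB

    std::printf("A tile: %d KiB, B tile: %d KiB, total: %d KiB\n",
                a_tile / 1024, b_tile / 1024, (a_tile + b_tile) / 1024);
    return 0;
}

That is 48 KiB of nominal staging; the authoritative per-kernel figure, including alignment, padding, and the C-shuffle stage, comes from GetSharedMemoryNumberOfByte above.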