[rocm-libraries] ROCm/rocm-libraries#4340 (commit 70a312f)

Implement device_grouped_gemm_fixed_nk_bias for RDNA4

## Proposed changes

Summary:

- Modified implementation for grouped_gemm_fixed_nk_bias
- FP16 WMMA examples
- WMMA instances
- Profiler for grouped_gemm_fixed_nk_bias
- Add WMMA instances to existing tests

**This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299
and should be merged after it.
Only the last 6 commits are in the scope of this PR.**

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation that helps the maintainers
understand the motivation behind the change
- [x] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose this solution and what
alternatives you considered.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yung-sheng Tu
2026-02-26 00:28:58 +00:00
committed by assistant-librarian[bot]
parent 9a32f0ea19
commit 75aea70c2c
11 changed files with 1514 additions and 40 deletions

View File

@@ -7,22 +7,55 @@
#include <sstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
/// Elementwise epilogue operation for split-K grouped GEMM with bias.
///
/// When a GEMM's K dimension is split into KBatch partial reductions, the
/// bias (x1) must be added exactly once. This operator writes the partial
/// result through unchanged for every batch except the last; only the final
/// batch (kbatch_id == KBatch - 1) performs the bias add.
///
/// NOTE(review): correctness relies on the caller invoking set_kbatch() with
/// the current batch id before the operator is applied — confirm all launch
/// paths do so (the kernel copies the op and calls set_kbatch per tile).
struct SplitKAdd
{
    static constexpr const char* name = "SplitKAdd";

    /// Record the current split-K batch id and the total batch count.
    /// index_t is a trivially-copyable integer, so take it by value rather
    /// than by const reference.
    __host__ __device__ void set_kbatch(const index_t id, const index_t total)
    {
        kbatch_id = id;
        KBatch    = total;
    }

    /// y <- x0 + x1 on the last split-K batch; y <- x0 otherwise.
    /// @param y  destination element
    /// @param x0 accumulated GEMM partial result
    /// @param x1 bias element (applied exactly once, on the final batch)
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        if(kbatch_id == KBatch - 1)
        {
            add_op(y, x0, x1);
        }
        else
        {
            passthrough_op(y, x0);
        }
    }

    private:
    // Which split-K batch this invocation belongs to, and how many there are.
    index_t kbatch_id = 0;
    index_t KBatch    = 1;

    // Stateless functors; constexpr so they occupy no per-object storage.
    static constexpr auto add_op         = Add{};
    static constexpr auto passthrough_op = PassThrough{};
};
} // namespace element_wise
namespace device {
template <typename GridwiseGemm,
@@ -110,6 +143,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
while(id_local < local_grid_size)
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off);
const auto tile_index =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(tile_index[Number<0>{}]);
auto c_element_op_copy(c_element_op);
if constexpr(std::is_same_v<decltype(c_element_op_copy),
ck::tensor_operation::element_wise::SplitKAdd>)
{
c_element_op_copy.set_kbatch(kbatch_id, k_batch_);
}
KernelArgument kernel_arg{std::array<const void*, 1>{gemmTransKernelArg.p_a_grid},
std::array<const void*, 1>{gemmTransKernelArg.p_b_grid},
gemmTransKernelArg.p_ds_grid,
@@ -124,15 +172,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
k_batch_,
a_element_op,
b_element_op,
c_element_op,
c_element_op_copy,
false};
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off);
const auto tile_index =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]);
@@ -813,7 +855,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK<ALayout
}
if constexpr(!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>)
ck::tensor_operation::element_wise::PassThrough> &&
!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::SplitKAdd>)
{
if(arg.k_batch_ > 1)
{

View File

@@ -8,27 +8,28 @@
#include <ostream>
#endif
#include "ck/utility/env.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
namespace ck {