Implement device_grouped_gemm_fixed_nk_bias for RDNA4 (#4340)

## Proposed changes

Summary:

- Modified implementation for grouped_gemm_fixed_nk_bias
- FP16 WMMA examples
- WMMA instances
- Profiler for grouped_gemm_fixed_nk_bias
- Add WMMA instances to existing tests

**This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299
and should be merged after it.
Only the last 6 commits are in the scope of this PR.**

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [x] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Yung-sheng Tu
2026-02-26 01:28:09 +01:00
committed by GitHub
parent bea67cb520
commit 00853e2bd2
11 changed files with 1514 additions and 40 deletions

View File

@@ -3,13 +3,14 @@
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/utility/type.hpp"
#include <memory>
#include <vector>
namespace ck {
namespace tensor_operation {
@@ -70,20 +71,52 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
PassThrough,
Add>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>>&
instances);
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>>&
instances);
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Row_Tuple,
DsLayout,
ELayout,
ADataType,
BDataType,
F32_Tuple,
DsDataType,
EDataType,
PassThrough,
PassThrough,
@@ -91,11 +124,11 @@ struct DeviceOperationInstanceFactory<
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Row_Tuple,
DsLayout,
ELayout,
ADataType,
BDataType,
F32_Tuple,
DsDataType,
EDataType,
PassThrough,
PassThrough,
@@ -112,28 +145,96 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
}
}
// fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#ifdef CK_USE_XDL
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
#endif
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#ifdef CK_USE_XDL
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
#endif
}
}
return op_ptrs;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemmFixedNK<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
}
}
return op_ptrs;