mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Implement device_grouped_gemm_fixed_nk_bias for RDNA4 (#4340)
## Proposed changes Summary: - Modified implementation for grouped_gemm_fixed_nk_bias - FP16 WMMA examples - WMMA instances - Profiler for grouped_gemm_fixed_nk_bias - Add WMMA instances to existing tests **This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299 and should be merged after it. Only the last 6 commits are in the scope of this PR.** ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [x] I have added inline documentation which enables the maintainers with understanding the motivation - [x] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -3,13 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -70,20 +71,52 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
|
||||
PassThrough,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ck::tensor_operation::element_wise::SplitKAdd>>>&
|
||||
instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ck::tensor_operation::element_wise::SplitKAdd>>>&
|
||||
instances);
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
Row_Tuple,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
F32_Tuple,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -91,11 +124,11 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
Row_Tuple,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
F32_Tuple,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -112,28 +145,96 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// fp32_output
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, float>)
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemmFixedNK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ck::tensor_operation::element_wise::SplitKAdd>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ck::tensor_operation::element_wise::SplitKAdd>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
|
||||
Reference in New Issue
Block a user