Implement grouped gemm fastgelu for RDNA4 (#3303)

* Implement grouped gemm fastgelu for RDNA4

* chore: some cleanup and minor inconsistencies in grouped gemm profiler

* chore: clarified logic and reporting of supported instance warnings

[ROCm/composable_kernel commit: f9c6ba0403]
This commit is contained in:
Erwin Terpstra
2026-01-07 19:20:44 +01:00
committed by GitHub
parent 6f6256381a
commit d074af36c9
24 changed files with 665 additions and 399 deletions

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -17,7 +17,10 @@ using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -17,7 +17,10 @@ using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<

View File

@@ -1,10 +1,15 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp
)

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Row,
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Col,
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Row,
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Col,
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck