mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Implement grouped gemm fastgelu for RDNA4 (#3303)
* Implement grouped gemm fastgelu for RDNA4
* chore: some cleanup and minor inconsistencies in grouped gemm profiler
* chore: clarified logic and reporting of supported instance warnings
[ROCm/composable_kernel commit: f9c6ba0403]
This commit is contained in:
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -20,9 +20,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -17,7 +17,10 @@ using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -17,7 +17,10 @@ using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
@@ -40,9 +43,9 @@ void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
# Source list handed to add_instance_library for device_grouped_gemm_fastgelu_instance.
# Each backend contributes one f16 translation unit per layout combination
# (mk_kn, mk_nk, km_kn, km_nk).
add_instance_library(device_grouped_gemm_fastgelu_instance
    # xdl-named sources
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp

    # wmma-named sources
    device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
    device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instance.cpp
    device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instance.cpp
)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_fastgelu_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user