created fastegelu instances

This commit is contained in:
Zoltan Lakatos
2025-06-11 14:14:43 +00:00
parent 4c8ea9517d
commit 264e1b238a
9 changed files with 541 additions and 72 deletions

View File

@@ -11,6 +11,7 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#if defined(CK_ENABLE_FP16)
namespace ck {
namespace tensor_operation {
namespace device {
@@ -136,18 +137,17 @@ template <typename ALayout,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
@@ -170,7 +170,6 @@ struct DeviceOperationInstanceFactory<
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
constexpr bool IsAllDRowLayout = is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row>;
constexpr bool IsAllDFloat16 =
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t>;
@@ -203,8 +202,6 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
return op_ptrs;
@@ -223,18 +220,18 @@ template <typename ALayout,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>>
struct DeviceOperationInstanceFactory<
DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -244,9 +241,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
PassThrough,
PassThrough,
AddAddFastGelu>;
static auto GetInstances()
{
@@ -317,3 +314,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // CK_ENABLE_FP16

View File

@@ -160,18 +160,17 @@ template <typename ALayout,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
@@ -241,18 +240,17 @@ template <typename ALayout,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -67,6 +68,132 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
PassThrough,
PassThrough,
FastGelu>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + FastGelu
// DeviceGemmMultipleDSplitK specialization
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + FastGelu
template <typename ALayout,
@@ -75,17 +202,17 @@ template <typename ALayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -103,6 +230,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
@@ -127,6 +255,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -136,4 +286,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#endif // CK_ENABLE_FP16

View File

@@ -1,5 +1,10 @@
# ONLY XDL_KERNELS
# XDL AND WMMA KERNELS
add_instance_library(device_gemm_fastgelu_instance
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -46,7 +46,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp)
@@ -89,8 +88,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
endif()
endif()
@@ -148,7 +148,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
@@ -202,8 +201,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
endif()
endif()