Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-04 21:11:23 +00:00
parent 5677205f88
commit 7f65be1b3e
51 changed files with 3709 additions and 189 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,70 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
// Forward declaration only; the definition is not part of this header.
//
// `instances` is taken by non-const reference. Its element type fixes the
// DeviceBatchedGemmGemm arguments to layouts <Row, Col, Row, Row>, four BF16
// data types and five PassThrough operations. Presumably the definition
// appends the concrete instances to the vector -- confirm against the
// defining source file.
//
// GetInstances() calls this when all four data types are bhalf_t and
// ALayout/B0Layout/B1Layout/CLayout are Row/Col/Row/Row.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Row,
Row,
BF16,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Forward declaration only; the definition is not part of this header.
//
// Same parameter shape as the `gno` BF16 declaration except that the third
// layout argument is Col instead of Row: the element type is
// DeviceBatchedGemmGemm<Row, Col, Col, Row, BF16 x4, PassThrough x5>.
// Presumably the definition appends the concrete instances to `instances`
// -- confirm against the defining source file.
//
// GetInstances() calls this when all four data types are bhalf_t and
// ALayout/B0Layout/B1Layout/CLayout are Row/Col/Col/Row.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Col,
Row,
BF16,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
// Forward declaration only; the definition is not part of this header.
//
// `instances` is taken by non-const reference. Its element type fixes the
// DeviceBatchedGemmGemm arguments to layouts <Row, Col, Row, Row>, four F16
// data types and five PassThrough operations. Presumably the definition
// appends the concrete instances to the vector -- confirm against the
// defining source file.
//
// GetInstances() calls this (under CK_USE_WMMA) when all four data types are
// half_t and ALayout/B0Layout/B1Layout/CLayout are Row/Col/Row/Row.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// Forward declaration only; the definition is not part of this header.
//
// Same parameter shape as the `gno` F16 declaration except that the third
// layout argument is Col instead of Row: the element type is
// DeviceBatchedGemmGemm<Row, Col, Col, Row, F16 x4, PassThrough x5>.
// Presumably the definition appends the concrete instances to `instances`
// -- confirm against the defining source file.
//
// GetInstances() calls this (under CK_USE_WMMA) when all four data types are
// half_t and ALayout/B0Layout/B1Layout/CLayout are Row/Col/Col/Row.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Col,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
@@ -46,6 +110,8 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_i
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
template <typename ALayout,
typename B0Layout,
typename B1Layout,
@@ -86,7 +152,46 @@ struct DeviceOperationInstanceFactory<
// Collects the device-op instances for this factory specialization.
//
// Starts from a default-constructed vector and, for each backend section that
// is compiled in (CK_USE_WMMA, CK_USE_XDL) and each enabled data type
// (CK_ENABLE_BF16, CK_ENABLE_FP16), passes the vector to the add_* function
// whose data types and layouts match the specialization's template arguments.
// Every type/layout test is `if constexpr`, so it is resolved at compile time.
// When none of the branches matches, the vector is returned as constructed.
//
// Returns: the std::vector<std::unique_ptr<DeviceOp>> by value.
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
// WMMA / BF16: only considered when all four data types are bhalf_t.
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<B0DataType, bhalf_t> &&
is_same_v<B1DataType, bhalf_t> && is_same_v<CDataType, bhalf_t>)
{
// Layouts Row/Col/Row/Row -> the `gmk_gnk_gno_gmo` instance set.
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
// Layouts Row/Col/Col/Row (B1 is Col) -> the `gmk_gnk_gon_gmo` instance set.
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
// Any other layout combination adds nothing in this section.
}
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
// WMMA / FP16: only considered when all four data types are half_t.
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
// Layouts Row/Col/Row/Row -> the `gmk_gnk_gno_gmo` instance set.
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
// Layouts Row/Col/Col/Row (B1 is Col) -> the `gmk_gnk_gon_gmo` instance set.
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
// XDL / FP16: a separate `if constexpr` (not an `else` of the WMMA section),
// so it is evaluated independently when both CK_USE_WMMA and CK_USE_XDL are
// defined. NOTE(review): the layout dispatch inside this block is only partly
// visible here -- confirm which add_* functions it calls.
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
@@ -103,10 +208,11 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation