Merge commit 'b09121f86066381f3662fdbdee6a810849a8a1a7' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-20 10:16:09 +00:00
parent 38c7251ed1
commit 43058803dc
13 changed files with 1345 additions and 78 deletions

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "profiler/common.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
@@ -33,6 +34,8 @@ using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
#ifdef CK_ENABLE_FP16
#ifdef CK_USE_XDL
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -44,6 +47,22 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif // CK_USE_WMMA
#endif // CK_ENABLE_FP16
} // namespace instance
} // namespace device
@@ -210,6 +229,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
// add device GEMM instances
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
@@ -217,35 +237,64 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
gemm_ptrs);
#endif
}
}
#endif // CK_ENABLE_FP16
if(gemm_ptrs.size() <= 0)
{
@@ -318,9 +367,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
bool c_error = ck::utils::check_err(c_g_m_n_device_result,
c_g_m_n_host_result,
"Error: Device and Host results do not match!",
get_rtol<CDataType>(),
get_atol<CDataType>());
bool d0_error = ck::utils::check_err(d0_g_m_device_result,
d0_g_m_host_result,
"Error: Device and Host results do not match!",
get_rtol<ReduceDataType>(),
get_atol<ReduceDataType>());
bool d1_error = ck::utils::check_err(d1_g_m_device_result,
d1_g_m_host_result,
"Error: Device and Host results do not match!",
get_rtol<ReduceDataType>(),
get_atol<ReduceDataType>());
pass = pass && (c_error == true);
pass = pass && (d0_error == true);