Merge commit '3784c0e7c395af214fdddd5f702691b354bfe8d4' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-12 20:14:45 +00:00
parent 90e4b6bfe9
commit 77527d2fa6
13 changed files with 1570 additions and 12 deletions

View File

@@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
#ifdef CK_USE_XDL
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif
} // namespace instance
} // namespace device
@@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
#endif
}
}
@@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification,
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
std::string gemm_name = gemm_ptr->GetTypeString();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
@@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification,
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::string gemm_name = gemm_ptr->GetTypeString();
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
@@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
if(do_log)
{
@@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification,
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
std::cout << gemm_name << ": does not support this GEMM problem" << std::endl;
}
}