Wmma support for gemm_reduce (#3145)

* Initial implementation of GEMM+Reduce:

 - device struct
 - epilogue struct

* Fix tests, improve profiler and add initial instances

* Add instances

* Fix compilation error

* Address review comments

* Fix logging

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 7414a0f4d4]
Author:  Enrico Degregori
Date:    2025-11-12 20:23:54 +01:00 (committed by GitHub)
Parent:  c8c5a7e1c6
Commit:  e00db44d0c
12 changed files with 1568 additions and 12 deletions
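For orientation before the diff: the profiler touched below verifies a fused GEMM whose output rows are additionally reduced into two M-length vectors (the `reduce0_m_*` and `reduce1_m_*` buffers in the code). The host-side sketch here is illustrative only; the concrete reduction and element-wise operations are configured by the profiler itself (the `Div, Div` element-wise ops visible in the diff suggest mean-style reductions), so the mean / mean-of-squares choice below is an assumption, not taken from the commit.

```cpp
#include <vector>

// Illustrative host reference for the fused GEMM+Reduce checked by the profiler.
// Assumption: reduce0 is a per-row mean and reduce1 a per-row mean of squares.
void gemm_reduce_host_ref(const std::vector<float>& a, // M x K, row-major
                          const std::vector<float>& b, // K x N, row-major
                          std::vector<float>& c,       // M x N, row-major
                          std::vector<float>& reduce0, // M
                          std::vector<float>& reduce1, // M
                          int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
    {
        float acc0 = 0.f, acc1 = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float v = 0.f;
            for(int k = 0; k < K; ++k)
                v += a[m * K + k] * b[k * N + n];
            c[m * N + n] = v;
            acc0 += v;     // first reduction over the N dimension
            acc1 += v * v; // second reduction over the N dimension
        }
        reduce0[m] = acc0 / N; // Div element-wise op applied to the reduced value (assumed mean)
        reduce1[m] = acc1 / N; // assumed mean of squares
    }
}
```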


@@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple<Div, Div>;
 using DeviceGemmReduceNoOpPtr =
     ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
+#ifdef CK_USE_XDL
 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
+#endif
+#ifdef CK_USE_WMMA
+void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
+    std::vector<DeviceGemmReduceNoOpPtr>&);
+void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
+    std::vector<DeviceGemmReduceNoOpPtr>&);
+void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
+    std::vector<DeviceGemmReduceNoOpPtr>&);
+void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
+    std::vector<DeviceGemmReduceNoOpPtr>&);
+#endif
 } // namespace instance
 } // namespace device
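The factory functions declared above are consumed by appending instances to a shared vector and interrogating each one through the generic device-op interface, as the profiler code further down does. A minimal sketch of that pattern, assuming it is compiled in the profiler translation unit where the declarations and the `DeviceGemmReduceNoOpPtr` alias above are visible (the helper function name is hypothetical):

```cpp
#include <iostream>
#include <vector>

// Sketch only: enumerate the new WMMA GEMM+Reduce instances and print their
// names, mirroring what the profiler does before running each kernel.
void list_wmma_gemm_reduce_instances()
{
    std::vector<DeviceGemmReduceNoOpPtr> gemm_ptrs;

    ck::tensor_operation::device::instance::
        add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
            gemm_ptrs);

    for(const auto& gemm_ptr : gemm_ptrs)
        std::cout << gemm_ptr->GetTypeString() << std::endl;
}
```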
@@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification,
                      is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                      is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
+#ifdef CK_USE_XDL
            ck::tensor_operation::device::instance::
                add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
                    gemm_ptrs);
+#endif
+#ifdef CK_USE_WMMA
+            ck::tensor_operation::device::instance::
+                add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
+#ifdef CK_USE_XDL
            ck::tensor_operation::device::instance::
                add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
                    gemm_ptrs);
+#endif
+#ifdef CK_USE_WMMA
+            ck::tensor_operation::device::instance::
+                add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
+#ifdef CK_USE_XDL
            ck::tensor_operation::device::instance::
                add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
                    gemm_ptrs);
+#endif
+#ifdef CK_USE_WMMA
+            ck::tensor_operation::device::instance::
+                add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
+                    gemm_ptrs);
+#endif
        }
        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
        {
+#ifdef CK_USE_XDL
            ck::tensor_operation::device::instance::
                add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
                    gemm_ptrs);
+#endif
+#ifdef CK_USE_WMMA
+            ck::tensor_operation::device::instance::
+                add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
+                    gemm_ptrs);
+#endif
        }
    }
@@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification,
        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
 
+        std::string gemm_name = gemm_ptr->GetTypeString();
+
        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            ++num_kernel;
@@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification,
            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
 
-            std::string gemm_name = gemm_ptr->GetTypeString();
-
            std::size_t flop = std::size_t(2) * M * N * K;
            std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
@@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification,
                reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
                reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
 
-                ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
-                ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
-                ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
+                pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+                pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
+                pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
 
                if(do_log)
                {
@@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification,
        }
        else
        {
-            std::cout << "does not support this GEMM problem" << std::endl;
+            std::cout << gemm_name << ": does not support this GEMM problem" << std::endl;
        }
    }
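One detail worth noting in the verification change above: the results are combined with the non-short-circuiting `&` rather than `&&`, so all three comparisons still execute (and can report their mismatches) even after an earlier one fails, while the aggregate `pass` flag still records any failure. A self-contained illustration of the pattern, outside the diff, with hypothetical stand-ins for the `ck::utils::check_err` calls:

```cpp
#include <iostream>

// Hypothetical stand-ins for check_err on the GEMM output and the two reductions.
bool check_output()  { return true; }
bool check_reduce0() { return false; } // pretend this comparison fails
bool check_reduce1() { return true; }

int main()
{
    bool pass = true;
    pass = pass & check_output();   // non-short-circuiting: every check runs
    pass = pass & check_reduce0();  // the failure is recorded in pass...
    pass = pass & check_reduce1();  // ...but the remaining checks still execute
    std::cout << (pass ? "pass" : "fail") << std::endl; // prints "fail"
    return pass ? 0 : 1;
}
```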