mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Wmma support for gemm_reduce (#3145)
* Initial implementation of GEMM+Reduce:
- device struct
- epilogue struct
* Fix tests, improve profiler and add initial instances
* Add instances
* Fix compilation error
* Address review comments
* Fix logging
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: 7414a0f4d4]
This commit is contained in:
@@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple<Div, Div>;
|
||||
using DeviceGemmReduceNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
@@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
|
||||
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
@@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
++num_kernel;
|
||||
@@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
@@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
|
||||
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
|
||||
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
|
||||
pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
@@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "does not support this GEMM problem" << std::endl;
|
||||
std::cout << gemm_name << ": does not support this GEMM problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user