mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Merge commit 'b09121f86066381f3662fdbdee6a810849a8a1a7' into develop
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
@@ -33,6 +34,8 @@ using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
|
||||
using DeviceGemmReduceNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
@@ -44,6 +47,22 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
|
||||
|
||||
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
|
||||
void add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
std::vector<DeviceGemmReduceNoOpPtr>&);
|
||||
#endif // CK_USE_WMMA
|
||||
#endif // CK_ENABLE_FP16
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
@@ -210,6 +229,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
// add device GEMM instances
|
||||
std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
|
||||
is_same<CDataType, half_t>::value)
|
||||
{
|
||||
@@ -217,35 +237,64 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
|
||||
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
ck::tensor_operation::device::instance::
|
||||
add_device_batched_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
|
||||
gemm_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
|
||||
if(gemm_ptrs.size() <= 0)
|
||||
{
|
||||
@@ -318,9 +367,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
|
||||
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
|
||||
|
||||
bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
|
||||
bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
|
||||
bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
|
||||
bool c_error = ck::utils::check_err(c_g_m_n_device_result,
|
||||
c_g_m_n_host_result,
|
||||
"Error: Device and Host results do not match!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
bool d0_error = ck::utils::check_err(d0_g_m_device_result,
|
||||
d0_g_m_host_result,
|
||||
"Error: Device and Host results do not match!",
|
||||
get_rtol<ReduceDataType>(),
|
||||
get_atol<ReduceDataType>());
|
||||
bool d1_error = ck::utils::check_err(d1_g_m_device_result,
|
||||
d1_g_m_host_result,
|
||||
"Error: Device and Host results do not match!",
|
||||
get_rtol<ReduceDataType>(),
|
||||
get_atol<ReduceDataType>());
|
||||
|
||||
pass = pass && (c_error == true);
|
||||
pass = pass && (d0_error == true);
|
||||
|
||||
Reference in New Issue
Block a user