mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
@@ -32,10 +32,14 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -32,11 +32,15 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_gemm_instance {
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
|
||||
}
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
|
||||
std::vector<DeviceGemmNoOpPtr>&);
|
||||
} // namespace device_gemm_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -55,10 +55,10 @@ check_err(const std::vector<T>& out,
|
||||
}
|
||||
|
||||
bool check_err(const std::vector<_Float16>& out,
|
||||
const std::vector<_Float16>& ref,
|
||||
const std::string& msg,
|
||||
_Float16 rtol = static_cast<_Float16>(1e-3f),
|
||||
_Float16 atol = static_cast<_Float16>(1e-3f))
|
||||
const std::vector<_Float16>& ref,
|
||||
const std::string& msg,
|
||||
_Float16 rtol = static_cast<_Float16>(1e-3f),
|
||||
_Float16 atol = static_cast<_Float16>(1e-3f))
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
@@ -69,14 +69,14 @@ bool check_err(const std::vector<_Float16>& out,
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<_Float16>::min();
|
||||
int err_count = 0;
|
||||
double err = 0;
|
||||
double max_err = std::numeric_limits<_Float16>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
double out_ = double(out[i]);
|
||||
double ref_ = double(ref[i]);
|
||||
err = std::abs(out_ - ref_);
|
||||
err = std::abs(out_ - ref_);
|
||||
if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
|
||||
Reference in New Issue
Block a user