CK pk_i4_t test failures fix (SWDEV-518629) (#2075)

* fix pk_i4_v3 tests failures in Unbuntu env.

* fix pk_i4_t tests failure on Unbuntu issues.

* some fixed.

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
Mingtao Gu
2025-04-14 16:58:57 +08:00
committed by GitHub
parent 269f4f6af5
commit 56378f810f
13 changed files with 148 additions and 42 deletions

View File

@@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
@@ -192,14 +192,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{

View File

@@ -134,7 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
@@ -242,14 +242,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{

View File

@@ -161,7 +161,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
@@ -274,14 +274,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{

View File

@@ -152,7 +152,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize() /
2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// do GEMM
@@ -261,14 +262,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{

View File

@@ -132,7 +132,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
@@ -240,14 +240,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" ||
ck::get_device_name() != "gfx950")
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{