mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Test fixes.
This commit is contained in:
@@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_mk_kn_mn_real = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_real_.mDesc.GetLengths()[1];
|
||||
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
|
||||
|
||||
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
|
||||
{
|
||||
@@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a_real;
|
||||
float v_b_real;
|
||||
@@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
|
||||
};
|
||||
|
||||
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
|
||||
const int K = arg.a_m_k_real_.mDesc.GetLengths()[1];
|
||||
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
|
||||
|
||||
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
|
||||
{
|
||||
@@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
float v_a_real;
|
||||
float v_b_real;
|
||||
|
||||
@@ -21,9 +21,9 @@
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceCGemmNoOpPtr =
|
||||
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -48,9 +48,9 @@ int main()
|
||||
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
bool res = true;
|
||||
std::vector<DeviceCGemmNoOpPtr> gemmPtrs;
|
||||
std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
|
||||
|
||||
ck::tensor_operation::device::device_gemm_instance::
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
for(auto& cgemmPtr : cgemmPtrs)
|
||||
@@ -76,7 +76,7 @@ int main()
|
||||
RowMajor,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>{}(gemmPtr);
|
||||
PassThrough>{}(cgemmPtr);
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceCGemmNoOpPtr =
|
||||
ck::tensor_operation::device::DevicecgemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
@@ -50,10 +50,7 @@ int main()
|
||||
|
||||
bool res = true;
|
||||
std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -72,10 +69,6 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -94,10 +87,6 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -116,14 +105,8 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
|
||||
|
||||
for(auto& cgemmPtr : cgemmPtrs)
|
||||
{
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceCGemmNoOpPtr =
|
||||
ck::tensor_operation::device::DevicecgemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
@@ -54,10 +54,7 @@ int main()
|
||||
|
||||
bool res = true;
|
||||
std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -76,10 +73,6 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -98,10 +91,6 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
|
||||
|
||||
@@ -120,10 +109,6 @@ int main()
|
||||
}
|
||||
|
||||
cgemmPtrs.clear();
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
|
||||
ck::tensor_operation::device::device_cgemm_instance::
|
||||
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
|
||||
|
||||
|
||||
@@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
|
||||
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
|
||||
DeviceMem aux_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
|
||||
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A_real.mDesc.GetElementSpace());
|
||||
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A_imag.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B_real.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
|
||||
DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(A.mData.data());
|
||||
b_k_n_device_buf.ToDevice(B.mData.data());
|
||||
a_m_k_real_device_buf.ToDevice(A_real.mData.data());
|
||||
a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
|
||||
b_k_n_real_device_buf.ToDevice(B_real.mData.data());
|
||||
b_k_n_imag_device_buf.ToDevice(B_imag.mData.data());
|
||||
|
||||
auto invoker_ptr = cgemmPtr->MakeInvokerPointer();
|
||||
auto argument_ptr = cgemmPtr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
|
||||
@@ -255,7 +257,7 @@ struct TestCGemm
|
||||
if(std::is_same<CDataType, float>::value)
|
||||
{
|
||||
res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) &&
|
||||
ck::utils::check_err(c_device_real.mData, c_host.mData);
|
||||
ck::utils::check_err(c_device_imag.mData, c_host_imag.mData);
|
||||
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
else if(std::is_same<CDataType, ck::half_t>::value)
|
||||
@@ -326,15 +328,13 @@ struct TestCGemmBF16
|
||||
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
||||
Tensor<float> b_k_n_imag_fp32(
|
||||
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
|
||||
Tensor<float> c_m_n_host_real_fp32(
|
||||
Tensor<float> c_m_n_real_host_fp32(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
Tensor<float> c_m_n_host_imag_fp32(
|
||||
Tensor<float> c_m_n_imag_host_fp32(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
Tensor<float> c_m_n_device_real_fp32(
|
||||
Tensor<float> c_m_n_real_device_fp32(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
Tensor<float> c_m_n_device_imag_fp32(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
Tensor<float> aux_fp32(
|
||||
Tensor<float> c_m_n_imag_device_fp32(
|
||||
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
|
||||
|
||||
a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
|
||||
@@ -361,8 +361,7 @@ struct TestCGemmBF16
|
||||
c_m_n_real_host_fp32,
|
||||
c_m_n_imag_host_fp32,
|
||||
c_m_n_real_device_fp32,
|
||||
c_m_n_imag_device_fp32,
|
||||
aux_fp32);
|
||||
c_m_n_imag_device_fp32);
|
||||
}
|
||||
|
||||
auto operator()(DeviceCGemmPtr_& cgemmPtr)
|
||||
@@ -392,43 +391,42 @@ struct TestCGemmBF16
|
||||
Tensor<float>& c_imag_host_fp32 = std::get<12>(host_tensors);
|
||||
Tensor<float>& c_real_device_fp32 = std::get<13>(host_tensors);
|
||||
Tensor<float>& c_imag_device_fp32 = std::get<14>(host_tensors);
|
||||
Tensor<float>& aux_fp32 = std::get<15>(host_tensors);
|
||||
|
||||
auto a_element_op = AElementwiseOperation{};
|
||||
auto b_element_op = BElementwiseOperation{};
|
||||
auto c_element_op = CElementwiseOperation{};
|
||||
|
||||
// use fp32 host kernel to verify bf16 device kernel
|
||||
using ReferenceGemmInstance =
|
||||
using ReferenceCGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceCGemm<float,
|
||||
float,
|
||||
float,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>;
|
||||
ck::gemm_util::RunHostCGEMM<ReferenceCGemmInstance>(a_real_fp32,
|
||||
a_imag_fp32,
|
||||
b_real_fp32,
|
||||
b_imag_fp32,
|
||||
c_real_host_fp32,
|
||||
c_imag_fp32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
ck::cgemm_util::RunHostCGEMM<ReferenceCGemmInstance>(a_real_fp32,
|
||||
a_imag_fp32,
|
||||
b_real_fp32,
|
||||
b_imag_fp32,
|
||||
c_real_host_fp32,
|
||||
c_imag_host_fp32,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
// Act
|
||||
ck::gemm_util::RunDeviceCGEMM(cgemmPtr,
|
||||
params,
|
||||
a_real_bf16,
|
||||
a_imag_bf16,
|
||||
b_real_bf16,
|
||||
b_imag_bf16,
|
||||
c_real_device_bf16,
|
||||
c_imag_device_bf16,
|
||||
aux_bf16,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
ck::cgemm_util::RunDeviceCGEMM(cgemmPtr,
|
||||
params,
|
||||
a_real_bf16,
|
||||
a_imag_bf16,
|
||||
b_real_bf16,
|
||||
b_imag_bf16,
|
||||
c_real_device_bf16,
|
||||
c_imag_device_bf16,
|
||||
aux_bf16,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
bf16_to_f32_(c_real_device_bf16, c_real_device_fp32);
|
||||
bf16_to_f32_(c_imag_device_bf16, c_imag_device_fp32);
|
||||
|
||||
Reference in New Issue
Block a user