mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Ck profiler splitk (#857)
* updated regular gemm * update ckProfiler * fixed gtests --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -94,7 +94,6 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_device_buf.SetZero();
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -136,60 +135,84 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_kbatch = 0;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 36, 40, 60,
|
||||
64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
if(KBatch > 0)
|
||||
{
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
kbatch_list = {KBatch};
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
for(std::size_t i = 0; i < kbatch_list.size(); i++)
|
||||
{
|
||||
auto kbatch_curr = kbatch_list[i];
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
kbatch_curr);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
// set softer tolerances for fp8
|
||||
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
|
||||
@@ -206,20 +229,20 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_kbatch = kbatch_curr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,7 +282,7 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << KBatch
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
|
||||
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
|
||||
<< " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<CDataType>> c_m_n_host_results;
|
||||
std::vector<Tensor<CDataType>> c_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
@@ -81,6 +82,9 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
c_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
#if DEBUG_LOG
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
|
||||
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
@@ -137,7 +141,6 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
|
||||
@@ -170,9 +173,36 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_kbatch = 0;
|
||||
|
||||
auto p_ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
c_m_n_host_results[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
@@ -193,139 +223,135 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(kbatch > 1)
|
||||
{
|
||||
using DeviceOpSplitK =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
|
||||
{
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch);
|
||||
}
|
||||
// skip non-splitk grouped_gemm
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
|
||||
|
||||
if(kbatch > 0)
|
||||
{
|
||||
kbatch_list = {kbatch};
|
||||
}
|
||||
|
||||
for(std::size_t j = 0; j < kbatch_list.size(); j++)
|
||||
{
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
auto kbatch_curr = kbatch_list[j];
|
||||
|
||||
if(time_kernel)
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
c_m_n_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
if(std::is_same_v<CDataType, ck::half_t> && kbatch > 1)
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
0.06);
|
||||
}
|
||||
else
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass &&
|
||||
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
|
||||
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
|
||||
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i],
|
||||
"Error: Incorrect results!",
|
||||
0.06);
|
||||
}
|
||||
else
|
||||
{
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i]);
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
pass = pass && instance_pass;
|
||||
if(time_kernel)
|
||||
{
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_kbatch = kbatch_curr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
Reference in New Issue
Block a user