Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 17:00:18 +00:00)
Improve 4k gemm perf (#1047)
* improve 4k gemm perf

* add f8 instances

* format

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
@@ -20,6 +21,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -69,14 +71,17 @@ int profile_gemm_impl(int do_verification,
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.01, 0.01});
|
||||
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -130,11 +135,10 @@ int profile_gemm_impl(int do_verification,
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
int best_instance_id = 0;
|
||||
|
||||
int instance_id = 0;
|
||||
// profile device op instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
@@ -178,10 +182,8 @@ int profile_gemm_impl(int do_verification,
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_instance_id = instance_id;
|
||||
best_tflops = tflops;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
@@ -205,53 +207,94 @@ int profile_gemm_impl(int do_verification,
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
|
||||
instance_id++;
|
||||
}
|
||||
|
||||
if constexpr(is_same<CDataType, float>::value)
|
||||
sleep(2);
|
||||
|
||||
// Run the best instance again
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, half_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, bhalf_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = bf16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = int8";
|
||||
}
|
||||
auto& op_ptr = op_ptrs[best_instance_id];
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 200});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
if constexpr(is_same<CDataType, float>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, half_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, bhalf_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = bf16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = int8";
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
else if constexpr(is_same<CDataType, f8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = fp8";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, f8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = fp8";
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_avg_time
|
||||
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << avg_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user