mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
* Partial Progress : Preshuffle working code for datatype * Partial Progress : Preshuffle Cleanup * Working code for default config with min max step * Partial Progress : PermuteN implemented in validation * Partial Progress : PermuteN changes in Preshuffle * CK Tile Engine Preshuffle Complete * CK TILE ENGINE : Preshuffle Layout validation * CK Tile Engine Preshuffle Validation * Preshuffle Validation check * CK Tile Engine Preshuffle : Fixing Validation Cases * Addressing PR review Comments * Changes in config * Addressing Review Comments * Adding additional architecture in Jenkins * Partial Progress : Selective Datatype and layouts * Limited datatypes and layouts * Addressing CI errors * Datatype updates * Datatype updates * Datatype changes to Preshuffle * Addressing Review Comments * Addressing Review Comments * Datatype changes * Changes to Cmake * Update on Jenkins * Formatting with precommit * Ruff Formatting
234 lines
8.1 KiB
C++
234 lines
8.1 KiB
C++
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host.hpp"
|
|
#include "gemm_preshuffle_common.hpp"
|
|
|
|
//[TODO] Move parts of this File to commons
|
|
enum class Metric
|
|
{
|
|
LATENCY = 0,
|
|
TFLOPS = 1,
|
|
BANDWIDTH = 2
|
|
};
|
|
|
|
inline constexpr auto get_metric_name(Metric m)
|
|
{
|
|
switch(m)
|
|
{
|
|
case Metric::LATENCY: return "latency";
|
|
case Metric::TFLOPS: return "tflops";
|
|
case Metric::BANDWIDTH: return "bandwidth";
|
|
default: throw std::invalid_argument("Unsupported metric type");
|
|
}
|
|
}
|
|
|
|
struct KernelConfig
|
|
{
|
|
std::tuple<int, int, int> tile_dims;
|
|
std::tuple<int, int, int> warp_dims;
|
|
std::tuple<int, int, int> warp_tile_dims;
|
|
bool permuteN;
|
|
};
|
|
|
|
struct GemmProblem
|
|
{
|
|
int split_k_;
|
|
int m_, n_, k_;
|
|
int stride_a_, stride_b_, stride_c_;
|
|
|
|
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
|
|
std::string layout_a_, layout_b_, layout_c_;
|
|
|
|
bool structured_sparsity_;
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
|
|
{
|
|
os << "{\n"
|
|
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
|
<< " \"m\":" << problem.m_ << ",\n"
|
|
<< " \"n\":" << problem.n_ << ",\n"
|
|
<< " \"k\":" << problem.k_ << ",\n"
|
|
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
|
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
|
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
|
|
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
|
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
|
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
|
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
|
|
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
|
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
|
<< " \"layout_c\":\"" << problem.layout_c_ << "\",\n"
|
|
<< " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false")
|
|
<< "\n"
|
|
<< "}";
|
|
return os;
|
|
}
|
|
};
|
|
|
|
struct PerformanceResult
|
|
{
|
|
double latency_;
|
|
double tflops_;
|
|
double bandwidth_;
|
|
|
|
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
|
{
|
|
switch(m)
|
|
{
|
|
case Metric::LATENCY: return a.latency_ < b.latency_;
|
|
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
|
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
|
default: throw std::invalid_argument("Unsupported metric type");
|
|
}
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
|
{
|
|
os << "{\n"
|
|
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
|
<< ",\n"
|
|
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
|
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
|
<< "}";
|
|
return os;
|
|
}
|
|
};
|
|
|
|
struct KernelInstance
|
|
{
|
|
std::string name_;
|
|
GemmProblem problem_;
|
|
PerformanceResult perf_result_;
|
|
|
|
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
|
{
|
|
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
|
}
|
|
|
|
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
|
{
|
|
os << "{\n"
|
|
<< " \"name\": \"" << obj.name_ << "\",\n"
|
|
<< " \"problem\": " << obj.problem_ << ",\n"
|
|
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
|
<< "}";
|
|
return os;
|
|
}
|
|
};
|
|
|
|
struct Setting
|
|
{
|
|
int n_warmup_;
|
|
int n_repeat_;
|
|
bool is_gpu_timer_;
|
|
int verify_;
|
|
int init_method_;
|
|
bool log_;
|
|
std::string csv_filename_;
|
|
bool flush_cache_;
|
|
int rotating_count_;
|
|
bool json_output_;
|
|
};
|
|
|
|
inline std::string get_rocm_version()
|
|
{
|
|
std::ifstream version_file("/opt/rocm/.info/version");
|
|
if(version_file.is_open())
|
|
{
|
|
std::string version;
|
|
std::getline(version_file, version);
|
|
return version;
|
|
}
|
|
return "Unknown";
|
|
}
|
|
|
|
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
|
auto calculate_rtol_atol(const ck_tile::index_t K,
|
|
const ck_tile::index_t kbatch,
|
|
const float max_accumulated_value)
|
|
{
|
|
using ComputeType =
|
|
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
|
// Calculate thresholds
|
|
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
|
ck_tile::integer_divide_ceil(K, kbatch));
|
|
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
|
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
|
// Calculate error due to split_k accumulation
|
|
const auto rtol_split_k =
|
|
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
|
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
|
max_accumulated_value, kbatch);
|
|
// Use higher threshold
|
|
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
|
}
|
|
|
|
/// @brief Function to compare the results of the device and host computations
|
|
bool compare(std::string instanceName,
|
|
ck_tile::index_t K,
|
|
ck_tile::index_t kbatch,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_ref)
|
|
{
|
|
const float max_accumulated_value =
|
|
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
|
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
|
K, kbatch, max_accumulated_value);
|
|
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
|
c_m_n_ref,
|
|
"Error: Incorrect results!",
|
|
rtol_atol.at(ck_tile::number<0>{}),
|
|
rtol_atol.at(ck_tile::number<1>{}));
|
|
|
|
std::cout << "For " << instanceName << " Relative error threshold is "
|
|
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
|
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
|
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
|
|
|
return pass;
|
|
}
|
|
|
|
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
|
void gemm_host_reference(int verify,
|
|
ck_tile::HostTensor<ADataType>& a_m_k,
|
|
ck_tile::HostTensor<BDataType>& b_k_n,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
|
ck_tile::DeviceMem& a_m_k_dev_buf,
|
|
ck_tile::DeviceMem& b_k_n_dev_buf,
|
|
ck_tile::index_t M,
|
|
ck_tile::index_t N,
|
|
ck_tile::index_t K,
|
|
ck_tile::index_t stride_A,
|
|
ck_tile::index_t stride_B,
|
|
ck_tile::index_t stride_C)
|
|
{
|
|
if(verify == 1)
|
|
{
|
|
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
|
a_m_k, b_k_n, c_m_n_ref);
|
|
}
|
|
else if(verify == 2)
|
|
{
|
|
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
|
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
|
|
|
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
|
c_m_n_gpu_buf_ref.SetZero();
|
|
|
|
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
|
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
|
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
|
|
|
ck_tile::reference_gemm_gpu<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
CDataType,
|
|
ALayout,
|
|
BLayout,
|
|
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
|
|
|
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
|
|
}
|
|
}
|