[CK TILE GEMM STREAMK] update identifier names according to the new code style (#3348)

* [CK TILE GEMM STREAMK] update identifier names according to the new code style
This commit is contained in:
Cong Ma
2025-12-12 17:08:26 -07:00
committed by GitHub
parent b4a34371a6
commit 9707ddb444
3 changed files with 328 additions and 312 deletions

View File

@@ -7,46 +7,46 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
struct GemmConfigBase
struct GemmConfigurationBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool PAD_M = true;
static constexpr bool PAD_N = true;
static constexpr bool PAD_K = true;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool PERMUTE_A = false;
static constexpr bool PERMUTE_B = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool TRANSPOSE_C = false;
static constexpr bool USE_STRUCTURED_SPARSITY = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int BLOCK_PER_CU = 1;
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
static constexpr bool PRESHUFFLE = false;
static constexpr bool DOUBLE_SMEM_BUFFER = false;
};
template <typename PrecType, bool Persistent_>
struct GemmConfigMemoryInterwave : public GemmConfigBase
template <typename PrecisionType, bool IsPersistent>
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 16;
static constexpr ck_tile::index_t M_TILE = 256;
static constexpr ck_tile::index_t N_TILE = 256;
static constexpr ck_tile::index_t K_TILE = 16;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_WARP = 2;
static constexpr ck_tile::index_t N_WARP = 2;
static constexpr ck_tile::index_t K_WARP = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr ck_tile::index_t M_WARP_TILE = 32;
static constexpr ck_tile::index_t N_WARP_TILE = 32;
static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;
static constexpr bool Persistent = Persistent_;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PERSISTENT = IsPersistent;
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
struct StreamKGemmTypeConfig
struct StreamKGemmTypeConfiguration
{
using ADataType = ADataType_;
using BDataType = BDataType_;
@@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig
using CDataType = CDataType_;
};
auto create_args(int argc, char* argv[])
auto createArgs(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "512", "m dimension")

View File

@@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout)
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
auto calculateRtolAtol(const ck_tile::index_t k_dim,
const ck_tile::index_t k_batch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
const auto relative_tolerance =
ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(k_dim, k_batch));
const auto absolute_tolerance =
ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
// Calculate error due to multiple WGs working in the same C macro tile
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
const auto relative_tolerance_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
const auto absolute_tolerance_split_k =
ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
k_batch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
std::max(absolute_tolerance, absolute_tolerance_split_k));
}
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -45,102 +49,107 @@ template <typename GemmConfig,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s);
const ck_tile::stream_config& stream_config);
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, ck_tile::index_t>
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
ck_tile::DeviceMem& b_k_n_device_memory,
ck_tile::DeviceMem& c_m_n_device_memory,
ck_tile::index_t m_dim,
ck_tile::index_t n_dim,
ck_tile::index_t k_dim,
ck_tile::index_t stride_a,
ck_tile::index_t stride_b,
ck_tile::index_t stride_c,
int warmup_iterations,
int repeat_iterations,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C};
ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
b_k_n_device_memory.GetDeviceBuffer(),
c_m_n_device_memory.GetDeviceBuffer(),
m_dim,
n_dim,
k_dim,
stride_a,
stride_b,
stride_c};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
std::tuple<float, ck_tile::index_t> average_time_and_batch;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
average_time_and_batch = gemm<GemmConfiguration,
ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args,
ck_tile::stream_config{
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
}
else /*Reduction*/
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
average_time_and_batch = gemm<GemmConfiguration,
ADataType,
BDataType,
DsDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args,
ck_tile::stream_config{
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
}
return ave_time_and_batch;
return average_time_and_batch;
}
template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
const ck_tile::tuple<double, double>& rtol_atol,
const char* variant)
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
const ck_tile::HostTensor<CDataType>& c_m_n_reference,
const ck_tile::tuple<double, double>& relative_absolute_tolerances,
const char* variant)
{
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_ref,
bool pass = ck_tile::check_err(c_m_n_device_result,
c_m_n_reference,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
relative_absolute_tolerances.at(ck_tile::number<0>{}),
relative_absolute_tolerances.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "Relative error threshold: "
<< relative_absolute_tolerances.at(ck_tile::number<0>{})
<< " Absolute error threshold: "
<< relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
<< std::endl;
return pass;
}
ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy)
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
{
if(strategy == "atomic")
{
@@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string
}
}
template <typename GemmConfig,
typename TypeConfig,
template <typename GemmConfiguration,
typename TypeConfiguration,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
int runGemmExampleWithLayouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
auto [result, arg_parser] = createArgs(argc, argv);
if(!result)
return -1;
static_assert(!GemmConfig::Preshuffle, "Not implemented");
static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented");
static_assert(!GemmConfig::PermuteA, "Not implemented");
static_assert(!GemmConfig::PermuteB, "Not implemented");
static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
using ADataType = typename TypeConfiguration::ADataType;
using BDataType = typename TypeConfiguration::BDataType;
using AccumulatorDataType = typename TypeConfiguration::AccDataType;
using CDataType = typename TypeConfiguration::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t m_dim = arg_parser.get_int("m");
ck_tile::index_t n_dim = arg_parser.get_int("n");
ck_tile::index_t k_dim = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
int warmup_iterations = arg_parser.get_int("warmup");
int repeat_iterations = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool flush_cache = arg_parser.get_bool("flush_cache");
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ADataType> a_m_k_host(
ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_host(
ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_device_result(
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
a_m_k_host.SetZero();
b_k_n_host.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
a_m_k_device_memory.ToDevice(a_m_k_host.data());
b_k_n_device_memory.ToDevice(b_k_n_host.data());
c_m_n_device_memory.SetZero();
c_m_n_device_result.SetZero();
auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
ADataType,
BDataType,
ck_tile::tuple<>,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_device_memory,
b_k_n_device_memory,
c_m_n_device_memory,
m_dim,
n_dim,
k_dim,
stride_a,
stride_b,
stride_c,
warmup_iterations,
repeat_iterations,
flush_cache,
reduction_strategy);
auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy);
c_m_n_device_memory.FromDevice(c_m_n_device_result.data());
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim;
std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
sizeof(CDataType) * m_dim * n_dim;
float tflops = static_cast<float>(flop) / 1.E9 / average_time;
float gb_per_sec = num_byte / 1.E6 / average_time;
std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
<< " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
<< " C_Layout=" << CLayout::name
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
bool pass = false;
// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_ref.SetZero();
ck_tile::HostTensor<CDataType> c_m_n_reference(
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
c_m_n_reference.SetZero();
if(arg_parser.get_int("v") == 1) // Validate on the CPU
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);
ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
a_m_k_host, b_k_n_host, c_m_n_reference);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
const auto relative_absolute_tolerances =
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
k_dim, num_wgs_per_tile, max_accumulated_value);
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
}
else if(arg_parser.get_int("v") == 2) // Validate on the GPU
{
// Memory on device to store gpu reference result
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
c_m_n_reference.get_element_space_size_in_bytes());
c_m_n_gpu_buffer_reference.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
AccumulatorDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
CLayout>(
d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
const auto relative_absolute_tolerances =
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
k_dim, num_wgs_per_tile, max_accumulated_value);
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
}
return pass;

View File

@@ -4,11 +4,11 @@
#include "gemm_utils.hpp"
#include "ck_tile/ops/common.hpp"
template <typename GemmConfig,
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -17,43 +17,49 @@ template <typename GemmConfig,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s)
const ck_tile::stream_config& stream_config)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
GemmConfiguration::N_TILE,
GemmConfiguration::K_TILE>,
ck_tile::sequence<GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::K_WARP>,
ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE>,
GemmConfiguration::PERMUTE_A,
GemmConfiguration::PERMUTE_B>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using TilePartitioner = ck_tile::
StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
GemmConfiguration::PAD_N,
GemmConfiguration::PAD_K,
GemmConfiguration::DOUBLE_SMEM_BUFFER,
ALayout,
BLayout,
ELayout,
GemmConfiguration::TRANSPOSE_C,
GemmConfiguration::USE_STRUCTURED_SPARSITY,
GemmConfiguration::PERSISTENT,
GemmConfiguration::NUM_WAVE_GROUPS,
GemmConfiguration::PRESHUFFLE>;
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccumulatorDataType,
GemmShape,
GemmUniversalTraits,
GemmConfiguration::SCHEDULER>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
@@ -61,39 +67,39 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
AccumulatorDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfiguration::M_WARP,
GemmConfiguration::N_WARP,
GemmConfiguration::M_WARP_TILE,
GemmConfiguration::N_WARP_TILE,
GemmConfiguration::K_WARP_TILE,
UniversalGemmProblem::TransposeC,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
GemmConfiguration::NUM_WAVE_GROUPS>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
auto kernel_args = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
if(stream_config.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
@@ -109,7 +115,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
}
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
@@ -120,45 +126,47 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
std::function<void()> preprocess = reset_data_buffers;
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
float average_time =
ck_tile::launch_kernel_time_mask(stream_config,
preprocess,
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
Kernel{}, grids, blocks, 0, kernel_args));
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{ave_time, num_wgs_per_tile};
ck_tile::index_t num_wgs_per_tile =
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{average_time, num_wgs_per_tile};
};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the same
// output tile in the C tensor, so we must atomic add
// the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the
// same output tile in the C tensor, so we must
// atomic add the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing final
// results to each macro tile in the C tensor, so we
// can do a set.
ck_tile::memory_operation_enum::set>{});
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing
// final results to each macro tile in the C
// tensor, so we can do a set.
ck_tile::memory_operation_enum::set>{});
}
}
#include "run_gemm_example.inc"
template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
template <typename GemmConfiguration, typename TypeConfiguration>
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig>(
return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
argc, argv, Row{}, Col{}, Row{});
}
else
@@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
template <template <typename PreType, bool Persistent_> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
int runGemmExample(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
auto [result, arg_parser] = createArgs(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_dp = arg_parser.get_bool("persistent_dp");
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
if(persistent_dp)
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
if(persistent_dp)
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else
@@ -247,5 +257,5 @@ int run_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigMemoryInterwave>(argc, argv);
return !runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
}