mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE GEMM STREAMK] update identifier names according to the new code style (#3348)
* [CK TILE GEMM STREAMK] update identifier names according to the new code style
This commit is contained in:
@@ -7,46 +7,46 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
struct GemmConfigurationBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool PAD_M = true;
|
||||
static constexpr bool PAD_N = true;
|
||||
static constexpr bool PAD_K = true;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr bool PERMUTE_A = false;
|
||||
static constexpr bool PERMUTE_B = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool TRANSPOSE_C = false;
|
||||
static constexpr bool USE_STRUCTURED_SPARSITY = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr int BLOCK_PER_CU = 1;
|
||||
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
|
||||
static constexpr bool PRESHUFFLE = false;
|
||||
static constexpr bool DOUBLE_SMEM_BUFFER = false;
|
||||
};
|
||||
|
||||
template <typename PrecType, bool Persistent_>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
template <typename PrecisionType, bool IsPersistent>
|
||||
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 16;
|
||||
static constexpr ck_tile::index_t M_TILE = 256;
|
||||
static constexpr ck_tile::index_t N_TILE = 256;
|
||||
static constexpr ck_tile::index_t K_TILE = 16;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
static constexpr ck_tile::index_t M_WARP = 2;
|
||||
static constexpr ck_tile::index_t N_WARP = 2;
|
||||
static constexpr ck_tile::index_t K_WARP = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
static constexpr ck_tile::index_t M_WARP_TILE = 32;
|
||||
static constexpr ck_tile::index_t N_WARP_TILE = 32;
|
||||
static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool PERSISTENT = IsPersistent;
|
||||
static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
|
||||
struct StreamKGemmTypeConfig
|
||||
struct StreamKGemmTypeConfiguration
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using BDataType = BDataType_;
|
||||
@@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
auto createArgs(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "512", "m dimension")
|
||||
|
||||
@@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout)
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
auto calculateRtolAtol(const ck_tile::index_t k_dim,
|
||||
const ck_tile::index_t k_batch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto relative_tolerance =
|
||||
ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(k_dim, k_batch));
|
||||
const auto absolute_tolerance =
|
||||
ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
|
||||
// Calculate error due to multiple WGs working in the same C macro tile
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
const auto relative_tolerance_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
|
||||
const auto absolute_tolerance_split_k =
|
||||
ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
|
||||
k_batch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
|
||||
std::max(absolute_tolerance, absolute_tolerance_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -45,102 +49,107 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
std::tuple<float, ck_tile::index_t>
|
||||
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
|
||||
ck_tile::DeviceMem& b_k_n_device_memory,
|
||||
ck_tile::DeviceMem& c_m_n_device_memory,
|
||||
ck_tile::index_t m_dim,
|
||||
ck_tile::index_t n_dim,
|
||||
ck_tile::index_t k_dim,
|
||||
ck_tile::index_t stride_a,
|
||||
ck_tile::index_t stride_b,
|
||||
ck_tile::index_t stride_c,
|
||||
int warmup_iterations,
|
||||
int repeat_iterations,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
|
||||
b_k_n_device_memory.GetDeviceBuffer(),
|
||||
c_m_n_device_memory.GetDeviceBuffer(),
|
||||
m_dim,
|
||||
n_dim,
|
||||
k_dim,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c};
|
||||
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
std::tuple<float, ck_tile::index_t> average_time_and_batch;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ave_time_and_batch = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
|
||||
average_time_and_batch = gemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
|
||||
}
|
||||
else /*Reduction*/
|
||||
{
|
||||
ave_time_and_batch = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
|
||||
average_time_and_batch = gemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
|
||||
}
|
||||
|
||||
return ave_time_and_batch;
|
||||
return average_time_and_batch;
|
||||
}
|
||||
|
||||
template <typename CDataType>
|
||||
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
const ck_tile::tuple<double, double>& rtol_atol,
|
||||
const char* variant)
|
||||
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_reference,
|
||||
const ck_tile::tuple<double, double>& relative_absolute_tolerances,
|
||||
const char* variant)
|
||||
{
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_ref,
|
||||
bool pass = ck_tile::check_err(c_m_n_device_result,
|
||||
c_m_n_reference,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
relative_absolute_tolerances.at(ck_tile::number<0>{}),
|
||||
relative_absolute_tolerances.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "Relative error threshold: "
|
||||
<< relative_absolute_tolerances.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: "
|
||||
<< relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
|
||||
<< std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy)
|
||||
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
|
||||
{
|
||||
if(strategy == "atomic")
|
||||
{
|
||||
@@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename TypeConfiguration,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
int runGemmExampleWithLayouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
auto [result, arg_parser] = createArgs(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
static_assert(!GemmConfig::Preshuffle, "Not implemented");
|
||||
static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented");
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
static_assert(!GemmConfig::PermuteB, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
|
||||
static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
|
||||
static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
using ADataType = typename TypeConfiguration::ADataType;
|
||||
using BDataType = typename TypeConfiguration::BDataType;
|
||||
using AccumulatorDataType = typename TypeConfiguration::AccDataType;
|
||||
using CDataType = typename TypeConfiguration::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
ck_tile::index_t m_dim = arg_parser.get_int("m");
|
||||
ck_tile::index_t n_dim = arg_parser.get_int("n");
|
||||
ck_tile::index_t k_dim = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
|
||||
int warmup_iterations = arg_parser.get_int("warmup");
|
||||
int repeat_iterations = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool flush_cache = arg_parser.get_bool("flush_cache");
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy =
|
||||
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
|
||||
getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
|
||||
stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
|
||||
stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::HostTensor<ADataType> a_m_k_host(
|
||||
ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_host(
|
||||
ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_device_result(
|
||||
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
a_m_k_host.SetZero();
|
||||
b_k_n_host.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
a_m_k_device_memory.ToDevice(a_m_k_host.data());
|
||||
b_k_n_device_memory.ToDevice(b_k_n_host.data());
|
||||
c_m_n_device_memory.SetZero();
|
||||
c_m_n_device_result.SetZero();
|
||||
auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_device_memory,
|
||||
b_k_n_device_memory,
|
||||
c_m_n_device_memory,
|
||||
m_dim,
|
||||
n_dim,
|
||||
k_dim,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
warmup_iterations,
|
||||
repeat_iterations,
|
||||
flush_cache,
|
||||
reduction_strategy);
|
||||
|
||||
auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
flush_cache,
|
||||
reduction_strategy);
|
||||
c_m_n_device_memory.FromDevice(c_m_n_device_result.data());
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim;
|
||||
std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
|
||||
sizeof(CDataType) * m_dim * n_dim;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / average_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / average_time;
|
||||
std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
|
||||
<< " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
|
||||
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
|
||||
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
|
||||
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
bool pass = false;
|
||||
|
||||
// Memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_ref.SetZero();
|
||||
ck_tile::HostTensor<CDataType> c_m_n_reference(
|
||||
ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
|
||||
c_m_n_reference.SetZero();
|
||||
|
||||
if(arg_parser.get_int("v") == 1) // Validate on the CPU
|
||||
{
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
a_m_k_host, b_k_n_host, c_m_n_reference);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
|
||||
const auto relative_absolute_tolerances =
|
||||
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
k_dim, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2) // Validate on the GPU
|
||||
{
|
||||
// Memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
|
||||
c_m_n_reference.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buffer_reference.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
|
||||
CLayout>(
|
||||
d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
|
||||
c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
|
||||
*std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
|
||||
const auto relative_absolute_tolerances =
|
||||
calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
|
||||
k_dim, num_wgs_per_tile, max_accumulated_value);
|
||||
pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -17,43 +17,49 @@ template <typename GemmConfig,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
const ck_tile::stream_config& stream_config)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
|
||||
GemmConfiguration::N_TILE,
|
||||
GemmConfiguration::K_TILE>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::K_WARP>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE>,
|
||||
GemmConfiguration::PERMUTE_A,
|
||||
GemmConfiguration::PERMUTE_B>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
|
||||
using TilePartitioner = ck_tile::
|
||||
StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
|
||||
GemmConfiguration::PAD_N,
|
||||
GemmConfiguration::PAD_K,
|
||||
GemmConfiguration::DOUBLE_SMEM_BUFFER,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfiguration::TRANSPOSE_C,
|
||||
GemmConfiguration::USE_STRUCTURED_SPARSITY,
|
||||
GemmConfiguration::PERSISTENT,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
@@ -61,39 +67,39 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation.value,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
@@ -109,7 +115,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
@@ -120,45 +126,47 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
ck_tile::index_t num_wgs_per_tile =
|
||||
kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the same
|
||||
// output tile in the C tensor, so we must atomic add
|
||||
// the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the
|
||||
// same output tile in the C tensor, so we must
|
||||
// atomic add the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing final
|
||||
// results to each macro tile in the C tensor, so we
|
||||
// can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// In this case, there is only ever 1 WG writing
|
||||
// final results to each macro tile in the C
|
||||
// tensor, so we can do a set.
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
template <typename GemmConfiguration, typename TypeConfiguration>
|
||||
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig>(
|
||||
return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType, bool Persistent_> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
|
||||
int runGemmExample(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
auto [result, arg_parser] = createArgs(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_dp = arg_parser.get_bool("persistent_dp");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
if(persistent_dp)
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -247,5 +257,5 @@ int run_gemm_example(int argc, char* argv[])
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigMemoryInterwave>(argc, argv);
|
||||
return !runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user