Revert "[CK_TILE] Tile loop persistent gemm kernel (#2191)" (#2293)

This reverts commit ffb52783d0.
This commit is contained in:
Illia Silin
2025-06-05 09:24:00 -07:00
committed by GitHub
parent 7ea1508b59
commit 233e274077
10 changed files with 18 additions and 232 deletions

View File

@@ -18,12 +18,9 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;

View File

@@ -213,8 +213,7 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent");
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -227,6 +226,5 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent = false>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -162,8 +162,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat,
bool persistent)
int n_repeat)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
@@ -177,31 +176,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time;
if(persistent)
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
true>(
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
else
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
false>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -216,8 +193,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
@@ -252,7 +229,6 @@ int run_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -340,8 +316,7 @@ int run_gemm_example_with_layouts(int argc,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent);
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -19,8 +19,7 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent>
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -49,8 +48,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent>;
GemmConfig::UseStructuredSparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -100,15 +98,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))