[TileEngine] Support for sparsity in codegen (#2128)

* Added sparsity flag in codegen

* remove comments

* clan formatted

* added sparsity as runtime argument

* updated README

* updated stream config variable

* fix typo for tail_num in hot loop
This commit is contained in:
Khushbu Agarwal
2025-04-28 18:19:23 -07:00
committed by GitHub
parent 4094ad158a
commit 768c99eca9
4 changed files with 56 additions and 35 deletions

View File

@@ -69,7 +69,7 @@ HOT_LOOP_FALSE = """
else if(tail_num == ck_tile::TailNumber::Even)
{
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
@@ -347,7 +347,8 @@ namespace {group_name} {{
return f"""
template <int TileM, int TileN, int TileK,
int WarpM, int WarpN, int WarpK,
int WarpTileM, int WarpTileN, int WarpTileK>
int WarpTileM, int WarpTileN, int WarpTileK,
bool structured_sparsity>
struct GemmKernel {{
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
@@ -356,7 +357,7 @@ struct GemmKernel {{
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
static constexpr bool TransposeC = false;
static constexpr int kBlockPerCu = 1;
@@ -381,7 +382,7 @@ struct GemmKernel {{
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, CLayout, TransposeC>;
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -494,7 +495,7 @@ struct GemmDispatcher {
return kernel_map;
}
static void init() {
static void init(bool structured_sparsity) {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
@@ -513,11 +514,11 @@ struct GemmDispatcher {
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s) {{
const ck_tile::stream_config& stream) {{
"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
@@ -526,7 +527,11 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
if(structured_sparsity) {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}} else {{
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
}}"""
content += f"""
}};\n"""
@@ -536,9 +541,9 @@ struct GemmDispatcher {
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
{
float avg_time = Kernel::launch(args, s);
float avg_time = Kernel::launch(args, stream);
std::string description = Kernel::get_name();
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
@@ -559,13 +564,13 @@ struct GemmDispatcher {
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& s) {
init();
int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& stream) {
init(structured_sparsity);
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
}
throw std::runtime_error("No suitable kernel found: " + key);
}