mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[TileEngine] Support for sparsity in codegen (#2128)
* Added sparsity flag in codegen * remove comments * clan formatted * added sparsity as runtime argument * updated README * updated stream config variable * fix typo for tail_num in hot loop
This commit is contained in:
@@ -69,7 +69,7 @@ HOT_LOOP_FALSE = """
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,7 +347,8 @@ namespace {group_name} {{
|
||||
return f"""
|
||||
template <int TileM, int TileN, int TileK,
|
||||
int WarpM, int WarpN, int WarpK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK>
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
bool structured_sparsity>
|
||||
struct GemmKernel {{
|
||||
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
|
||||
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
|
||||
@@ -356,7 +357,7 @@ struct GemmKernel {{
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
@@ -381,7 +382,7 @@ struct GemmKernel {{
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC>;
|
||||
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -494,7 +495,7 @@ struct GemmDispatcher {
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init() {
|
||||
static void init(bool structured_sparsity) {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
@@ -513,11 +514,11 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
const ck_tile::stream_config& stream) {{
|
||||
"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
@@ -526,7 +527,11 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
if(structured_sparsity) {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}} else {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}}"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
|
||||
@@ -536,9 +541,9 @@ struct GemmDispatcher {
|
||||
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
std::string description = Kernel::get_name();
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
@@ -559,13 +564,13 @@ struct GemmDispatcher {
|
||||
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
init(structured_sparsity);
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user