mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[TileEngine] Support for sparsity in codegen (#2128)
* Added sparsity flag in codegen * remove comments * clan formatted * added sparsity as runtime argument * updated README * updated stream config variable * fix typo for tail_num in hot loop
This commit is contained in:
@@ -20,24 +20,25 @@ make tile_engine_gemm -j
|
||||
## tile_engine_gemm inputs
|
||||
```
|
||||
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
-pad_m Pad in m direction - true/false (default:false)
|
||||
-pad_n Pad in n direction - true/false (default:false)
|
||||
-pad_k Pad in k direction - true/false (default:false)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
-pad_m Pad in m direction - true/false (default:false)
|
||||
-pad_n Pad in n direction - true/false (default:false)
|
||||
-pad_k Pad in k direction - true/false (default:false)
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
|
||||
```
|
||||
|
||||
20
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file → Executable file
20
tile_engine/ops/gemm/gemm_host_api.cpp
Normal file → Executable file
@@ -10,12 +10,19 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
bool structured_sparsity,
|
||||
KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
const ck_tile::stream_config& stream)
|
||||
{
|
||||
return GemmDispatcher::dispatch(
|
||||
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s);
|
||||
return GemmDispatcher::dispatch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
args,
|
||||
stream);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -43,6 +50,7 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -76,6 +84,11 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(structured_sparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
@@ -153,6 +166,7 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
1
tile_engine/ops/gemm/gemm_host_api.hpp
Normal file → Executable file
1
tile_engine/ops/gemm/gemm_host_api.hpp
Normal file → Executable file
@@ -118,6 +118,7 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("structured_sparsity", "0", "0:false, 1:true")
|
||||
.insert("pipeline", "compv3", "compv3, compv4, mem")
|
||||
.insert("scheduler", "intrawave", "intrawave, interwave")
|
||||
.insert("epilogue", "cshuffle", "cshuffle, default")
|
||||
|
||||
@@ -69,7 +69,7 @@ HOT_LOOP_FALSE = """
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,7 +347,8 @@ namespace {group_name} {{
|
||||
return f"""
|
||||
template <int TileM, int TileN, int TileK,
|
||||
int WarpM, int WarpN, int WarpK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK>
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
bool structured_sparsity>
|
||||
struct GemmKernel {{
|
||||
static constexpr bool kPadM = {BOOL_MAP(kPadM)};
|
||||
static constexpr bool kPadN = {BOOL_MAP(kPadN)};
|
||||
@@ -356,7 +357,7 @@ struct GemmKernel {{
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
@@ -381,7 +382,7 @@ struct GemmKernel {{
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC>;
|
||||
ALayout, BLayout, CLayout, TransposeC, structured_sparsity>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -494,7 +495,7 @@ struct GemmDispatcher {
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init() {
|
||||
static void init(bool structured_sparsity) {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
@@ -513,11 +514,11 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s) {{
|
||||
const ck_tile::stream_config& stream) {{
|
||||
"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
@@ -526,7 +527,11 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
if(structured_sparsity) {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}} else {{
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);
|
||||
}}"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
|
||||
@@ -536,9 +541,9 @@ struct GemmDispatcher {
|
||||
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
std::string description = Kernel::get_name();
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
@@ -559,13 +564,13 @@ struct GemmDispatcher {
|
||||
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
init(structured_sparsity);
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream);
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user