diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 495232f19b..08456a1675 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -20,24 +20,25 @@ make tile_engine_gemm -j ## tile_engine_gemm inputs ``` - -m m dimension (default:3840) - -n n dimension (default:4096) - -k k dimension (default:2048) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -split_k SplitK value (default:1) - -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) - -warmup Number of iterations before benchmark the kernel (default:50) - -repeat Number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) - -pipeline possible values are: compv3, compv4, mem (default:compv3) - -scheduler possible values are: intrawave, interwave (default:intrawave) - -epilogue possible values are: cshuffle, default (default:cshuffle) - -pad_m Pad in m direction - true/false (default:false) - -pad_n Pad in n direction - true/false (default:false) - -pad_k Pad in k direction - true/false (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -split_k SplitK value (default:1) + -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -warmup Number of iterations before benchmark the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) +-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0) + -pipeline possible values are: compv3, compv4, mem (default:compv3) + -scheduler possible values are: intrawave, interwave (default:intrawave) + -epilogue possible values are: cshuffle, default (default:cshuffle) + -pad_m Pad in m direction - true/false (default:false) + -pad_n Pad in n direction - true/false (default:false) + -pad_k Pad in k direction - true/false (default:false) Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json ``` diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp old mode 100644 new mode 100755 index 3cef425a51..a5447cd658 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -10,12 +10,19 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, + bool structured_sparsity, KernelTraits& trait, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + const ck_tile::stream_config& stream) { - return GemmDispatcher::dispatch( - c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s); + return GemmDispatcher::dispatch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + structured_sparsity, + trait, + args, + stream); } template {}(a_m_k); + } + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -153,6 +166,7 @@ void run(const ck_tile::ArgParser& arg_parser) c_m_n_host_result, c_m_n_dev_result, verify, + structured_sparsity, trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp old mode 100644 new mode 100755 index c1e1e1dc4f..579d2770db --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -118,6 +118,7 @@ inline auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("structured_sparsity", "0", "0:false, 1:true") .insert("pipeline", "compv3", "compv3, compv4, mem") .insert("scheduler", "intrawave", "intrawave, interwave") .insert("epilogue", "cshuffle", "cshuffle, default") diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index cfefd38cd2..b6c7685fb2 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -69,7 +69,7 @@ HOT_LOOP_FALSE = """ else if(tail_num == ck_tile::TailNumber::Even) { Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else { @@ -347,7 +347,8 @@ namespace {group_name} {{ return f""" template + int WarpTileM, int WarpTileN, int WarpTileK, + bool structured_sparsity> struct GemmKernel {{ static constexpr bool kPadM = {BOOL_MAP(kPadM)}; static constexpr bool kPadN = {BOOL_MAP(kPadN)}; @@ -356,7 +357,7 @@ struct GemmKernel {{ static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ static constexpr bool permuteA = false; static constexpr bool permuteB = false; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; static constexpr bool TransposeC = false; static constexpr int kBlockPerCu = 1; @@ -381,7 +382,7 @@ struct GemmKernel {{ using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + ALayout, BLayout, CLayout, TransposeC, structured_sparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -494,7 +495,7 @@ struct GemmDispatcher { return kernel_map; } - static void init() { + static void init(bool structured_sparsity) { auto& kernel_map = get_kernel_map(); if(!kernel_map.empty()) return; \n""" @@ -513,11 +514,11 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf, + content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) {{ + const ck_tile::stream_config& stream) {{ """ for tile in tile_params: # Check if we have valid tile/warp combinations @@ -526,7 +527,11 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" + if(structured_sparsity) {{ + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); + }} else {{ + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); + }}""" content += f""" }};\n""" @@ -536,9 +541,9 @@ struct GemmDispatcher { static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { - float avg_time = Kernel::launch(args, s); + float avg_time = Kernel::launch(args, stream); std::string description = Kernel::get_name(); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); @@ -559,13 +564,13 @@ struct GemmDispatcher { static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, - const ck_tile::stream_config& s) { - init(); + int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + const ck_tile::stream_config& stream) { + init(structured_sparsity); const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s); + return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream); } throw std::runtime_error("No suitable kernel found: " + key); }