mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[CK_TILE] Introduces a new GEMM API that splits the existing basic GEMM class into multiple specialized classes. (#2520)
* Init commit new API * apply clang-format * PreShuffle preparing * Apply Preshuffle condition to universal_gemm * Fix: convert size_t to index_t * Review changes * Mode 100755 -> 100644 --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -233,7 +233,7 @@ struct GemmKernel {{
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
|
||||
static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool permuteA = false;
|
||||
static constexpr bool permuteB = false;
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
@@ -335,7 +335,7 @@ struct GemmKernel {{
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {{
|
||||
@@ -680,7 +680,7 @@ struct GemmDispatcher {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
|
||||
kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
@@ -705,7 +705,7 @@ struct GemmDispatcher {
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile[j]
|
||||
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
sparse = (
|
||||
@@ -746,7 +746,7 @@ struct GemmDispatcher {
|
||||
content += """ }
|
||||
|
||||
template <typename Kernel>
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
std::string name = Kernel::get_name();
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
|
||||
@@ -22,7 +22,7 @@ class GemmProfiler
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
@@ -89,10 +89,9 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs<> gemm_args = {
|
||||
ck_tile::GemmHostArgs gemm_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{}, // ds_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
gemm_problem.m_,
|
||||
@@ -100,7 +99,6 @@ class GemmProfiler
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
{}, // stride_Ds
|
||||
gemm_problem.stride_c_,
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user