From 1fc7bdd402939e66a7c387eaa7d98852ff45a224 Mon Sep 17 00:00:00 2001
From: Yi DING
Date: Fri, 19 Dec 2025 10:28:13 +0800
Subject: [PATCH] [CK_TILE] MX Flatmm Use Byte Pointer Arithmetic for A Tensor
 (#3446)

* A as bytes

* Reformat with static_for_product

[ROCm/composable_kernel commit: 2220cbaba75892de5780f8f556554ee92ba19e29]
---
 example/ck_tile/18_flatmm/CMakeLists.txt      |   1 +
 include/ck_tile/core/utility/functional.hpp   |  28 ++
 ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 475 ++++++++----------
 ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 119 ++---
 4 files changed, 309 insertions(+), 314 deletions(-)

diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt
index 0fd819d552..696cb4f60b 100644
--- a/example/ck_tile/18_flatmm/CMakeLists.txt
+++ b/example/ck_tile/18_flatmm/CMakeLists.txt
@@ -20,6 +20,7 @@ if(has_supported_gpu)
     if(CK_USE_OCP_FP8)
         list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
     endif()
+    list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
 
     add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
     target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp
index 90740dcbe3..898d21574e 100644
--- a/include/ck_tile/core/utility/functional.hpp
+++ b/include/ck_tile/core/utility/functional.hpp
@@ -82,6 +82,34 @@ struct static_for<0, N, 1> : detail::make_applier<N>
     using detail::make_applier<N>::operator();
 };
 
+template <typename... Ns>
+struct static_for_product;
+template <index_t... Is>
+struct static_for_product<sequence<Is...>> : public static_for<Is...>
+{
+};
+template <index_t I>
+struct static_for_product<number<I>> : public static_for<0, I, 1>
+{
+};
+template <typename N, typename... Ns>
+struct static_for_product<N, Ns...>
+{
+    template <typename F>
+    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
+    {
+        static_for_product<N>{}([=](auto I) {
+            static_for_product<Ns...>{}([=](auto... 
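Note: `static_for_product<number<N0>, number<N1>, ...>` unrolls the Cartesian product of several compile-time ranges through a single lambda, which is what lets the pipeline below flatten its six-deep `static_for` nests; the template-argument lists in this hunk were reconstructed from the surrounding usage, so the exact set of specializations is a best-effort reading. A self-contained host-side sketch of the same idiom (plain C++17, with `std::integral_constant` standing in for ck_tile's `number<>`; all names here are illustrative, not the ck_tile API):

    #include <cstdio>
    #include <utility>

    // Stand-in for ck_tile::number<N> so the sketch compiles on its own.
    template <int N>
    using number = std::integral_constant<int, N>;

    template <typename... Ns>
    struct static_for_product;

    namespace detail {
    template <typename F, int... Is>
    constexpr void apply_each(F f, std::integer_sequence<int, Is...>)
    {
        (f(number<Is>{}), ...); // call f with number<0>, number<1>, ...
    }
    } // namespace detail

    // Base case: one extent degenerates to a plain 0..N-1 loop.
    template <int N>
    struct static_for_product<number<N>>
    {
        template <typename F>
        constexpr void operator()(F f) const
        {
            detail::apply_each(f, std::make_integer_sequence<int, N>{});
        }
    };

    // Recursive case: peel the first extent, recurse over the rest.
    template <typename N, typename... Ns>
    struct static_for_product<N, Ns...>
    {
        template <typename F>
        constexpr void operator()(F f) const
        {
            static_for_product<N>{}([=](auto i) {
                static_for_product<Ns...>{}([=](auto... is) { f(i, is...); });
            });
        }
    };

    int main()
    {
        // Visits (0,0) (0,1) (0,2) (1,0) (1,1) (1,2); the rightmost index moves
        // fastest, matching the innermost loop of the static_for nest it replaces.
        static_for_product<number<2>, number<3>>{}(
            [](auto i, auto j) { std::printf("(%d,%d) ", int(i), int(j)); });
        std::printf("\n");
    }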
Is) { //
+                f(I, Is...);
+            });
+        });
+    }
+};
+
 struct identity
 {
     template <typename T>

diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index c3843f1044..125d32aad8 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -521,43 +521,40 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-        auto a_copy_dram_window = make_tile_window(
-            PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor<Problem>(
-                a_copy_dram_window_tmp.get_bottom_tensor_view()),
-            a_copy_dram_window_tmp.get_window_lengths(),
-            a_copy_dram_window_tmp.get_window_origin(),
-            PipelinePolicy::template MakeMX_ADramTileDistribution<Problem>());
+        auto a_dram_window = PipelinePolicy::template MakeMX_AAsyncLoadBytesDramWindow<Problem>(
+            a_copy_dram_window_tmp);
 
         __builtin_amdgcn_sched_barrier(0);
 
         // A tile in LDS
-        ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
-        ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
-
         constexpr auto a_lds_block_desc =
-            PipelinePolicy::template MakeMX_ALdsBlockDescriptor<Problem>();
+            PipelinePolicy::template MakeMX_ALdsBytesBlockDescriptor<Problem>();
 
-        auto a_lds_block_ping =
-            make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
-        auto a_lds_block_pong =
-            make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
+        auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(
+            static_cast<int8_t*>(p_smem_ping), a_lds_block_desc);
+        auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(
+            static_cast<int8_t*>(p_smem_pong), a_lds_block_desc);
 
-        auto a_store_lds_window_ping = make_tile_window(
-            a_lds_block_ping, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
-        auto a_store_lds_window_pong = make_tile_window(
-            a_lds_block_pong, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
+        auto a_store_lds_window_ping = make_tile_window( //
+            a_lds_block_ping,
+            make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
+            {0, 0});
+        auto a_store_lds_window_pong = make_tile_window( //
+            a_lds_block_pong,
+            make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
+            {0, 0});
 
         // ping-pong window for A LDS
         auto a_warp_window_ping =
             make_tile_window(a_lds_block_ping,
-                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
+                             make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
                              {0, 0},
-                             PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
+                             PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
         auto a_warp_window_pong =
             make_tile_window(a_lds_block_pong,
-                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
+                             make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
                              {0, 0},
-                             PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
+                             PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
 
         // B flat DRAM window for load
@@ -624,7 +621,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -639,23 +636,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
         // prefetch Scale A
-        static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+        static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
                     scale_a_dram_window,
-                    mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
+                    impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
             });
         });
         // move Scale A window to next K
         move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
 
         // prefetch Scale B
-        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-            static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+        static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
+            static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
                     scale_b_dram_window,
-                    nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
+                    inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
             });
         });
         // move Scale B window to next K
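Note: the ping/pong windows above implement classic double buffering: while the warps compute out of one LDS buffer, the async loads fill the other, and the two swap every K-block pair. A minimal sketch of that control flow under simplified assumptions (hypothetical `load_k_block`/`compute_k_block` helpers standing in for buffer_load_lds and ds_read + MFMA; not the ck_tile API):

    #include <array>
    #include <cstdio>

    // Hypothetical stand-ins for the async LDS fill and the warp GEMM on one K-block.
    void load_k_block(std::array<int, 4>& lds_buf, int k_block) { lds_buf.fill(k_block); }
    void compute_k_block(const std::array<int, 4>& lds_buf) { std::printf("compute %d\n", lds_buf[0]); }

    int main()
    {
        std::array<int, 4> lds_ping{}, lds_pong{};
        const int num_k_blocks = 6;

        load_k_block(lds_ping, 0); // prologue: prefetch block 0 into ping
        for(int k = 0; k + 1 < num_k_blocks; k += 2)
        {
            load_k_block(lds_pong, k + 1); // fill pong while ping is consumed
            compute_k_block(lds_ping);     // GEMM 2i
            if(k + 2 < num_k_blocks)
                load_k_block(lds_ping, k + 2); // refill ping while pong is consumed
            compute_k_block(lds_pong);     // GEMM 2i+1
        }
        // an odd trailing block would be handled by a tail case (cf. TailNumber::Odd below)
    }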
@@ -666,7 +663,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                                  MIterPerWarp>
@@ -685,7 +682,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                 a_warp_tensor(number<m_iter>{}) = load_tile_with_offset(
-                    a_warp_window_ping, tuple<number<AmIter>, number<AkIter>>{});
+                    a_warp_window_ping,
+                    tuple<number<AmIter>, number<AkIter>>{});
             });
 
         __builtin_amdgcn_sched_barrier(0);
@@ -706,63 +704,55 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
             {
                 // prefetch Scale A
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                    static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
+                        scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
                             scale_a_dram_window,
-                            mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
+                            impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
                     });
                 });
 
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                        scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
+                        scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
                             scale_b_dram_window,
-                            nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
+                            inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
                     });
                 });
 
                 // GEMM 2i
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
-                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
-                                    constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
-                                    constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
-                                    constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
-                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                        constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-                                        // warp GEMM
-                                        WG{}.template
-                                        operator()(
-                                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
-                                            bit_cast<typename WG::AVecType>(
-                                                a_warp_tensor(number<AwarpIter>{})),
-                                            bit_cast<typename WG::BVecType>(
-                                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
-                                            scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0],
-                                            scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0]);
-                                    });
-                                    // preload next A from lds
-                                    constexpr auto addr =
-                                        m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
-                                    if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                                 (nIter_pack == NPackIterPerWarp - 1))
-                                    {
-                                        constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                        constexpr auto AkIter = addr / 2 % 2;
-                                        a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
-                                            a_warp_window_ping,
-                                            tuple<number<AmIter>, number<AkIter>>{});
-                                    }
-                                });
-                            });
-                        });
+                static_for_product<number<KPackIterPerWarp>,
+                                   number<MPackIterPerWarp>,
+                                   number<NPackIterPerWarp>,
+                                   number<KXdlPack>,
+                                   number<MXdlPack>,
+                                   number<NXdlPack>>{}( //
+                    [&](auto ikpack, auto impack, auto inpack, auto ikxdl, auto imxdl, auto inxdl) {
+                        constexpr auto n_iter    = inpack * NXdlPack + inxdl;
+                        constexpr auto m_iter    = impack * MXdlPack + imxdl;
+                        constexpr auto k_iter    = ikpack * KXdlPack + ikxdl;
+                        constexpr auto APackIter = ikxdl * MXdlPack + imxdl; // idx inside a xdl pack
+                        // warp GEMM
+                        WG{}.template operator()(
+                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
+                            bit_cast<typename WG::AVecType>(a_warp_tensor(number<APackIter>{})),
+                            bit_cast<typename WG::BVecType>(
+                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
+                            scale_a_tile_tensor_ping(impack)(ikpack).get_thread_buffer()[0],
+                            scale_b_tile_tensor_ping(inpack)(ikpack).get_thread_buffer()[0]);
+                        // preload next A from lds
+                        constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
+                        if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
+                                     (n_iter == NIterPerWarp - 1))
+                        {
+                            constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                            constexpr auto AkIter = addr / 2 % 2;
+                            a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
+                                a_warp_window_ping,
+                                tuple<number<AmIter>,
+                                      number<AkIter>>{});
+                        }
                     });
-                });
 
                 // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
                 s_waitcnt< // vmcnt
                     Bload_num + ScaleAload_num + ScaleBload_num>();
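Note: the `addr` arithmetic above serializes the (m_iter, k_iter) pairs into the issue order of the LDS A-loads and then inverts that order back into window coordinates; the `/ 2 % 2` terms hard-code KIterPerWarp == 2, and the kernel additionally offsets by m_preload to stay ahead of consumption. A small host-side check of the round trip (assumed extents, illustrative only):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        constexpr int MIterPerWarp = 4;
        constexpr int KIterPerWarp = 2; // the addr formula assumes exactly 2 K iters

        for(int m_iter = 0; m_iter < MIterPerWarp; ++m_iter)
            for(int k_iter = 0; k_iter < KIterPerWarp; ++k_iter)
            {
                // forward map: issue slot of the ds_read feeding (m_iter, k_iter)
                int addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4;
                // inverse map used when issuing the preload from the LDS window
                int AmIter = addr % 2 + addr / 4 * 2;
                int AkIter = addr / 2 % 2;
                assert(AmIter == m_iter && AkIter == k_iter);
                std::printf("slot %d -> (m=%d, k=%d)\n", addr, AmIter, AkIter);
            }
    }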
@@ -770,7 +760,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                     a_warp_tensor(number<m_iter>{}) = load_tile_with_offset(
-                        a_warp_window_pong, tuple<number<AmIter>, number<AkIter>>{});
+                        a_warp_window_pong,
+                        tuple<number<AmIter>, number<AkIter>>{});
                 });
 
                 HotLoopScheduler();
@@ -802,63 +793,55 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                 // prefetch Scale A
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                    static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
+                        scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
                             scale_a_dram_window,
-                            mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
+                            impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
                     });
                 });
 
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                        scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                    static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
+                        scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
                             scale_b_dram_window,
-                            nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
+                            inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
                     });
                 });
 
                 // GEMM 2i+1
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
-                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
-                                    constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
-                                    constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
-                                    constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
-                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                        constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-                                        // warp GEMM
-                                        WG{}.template
-                                        operator()(
-                                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
-                                            bit_cast<typename WG::AVecType>(
-                                                a_warp_tensor(number<AwarpIter>{})),
-                                            bit_cast<typename WG::BVecType>(
-                                                b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
-                                            scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0], // scale A
-                                            scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0]); // scale B
-                                    });
-                                    // preload next A from lds
-                                    constexpr auto addr =
-                                        m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
-                                    if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
                                                 (nIter_pack == NPackIterPerWarp - 1))
-                                    {
-                                        constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                        constexpr auto AkIter = addr / 2 % 2;
-                                        a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
-                                            a_warp_window_pong,
-                                            tuple<number<AmIter>, number<AkIter>>{});
-                                    }
-                                });
-                            });
-                        });
+                static_for_product<number<KPackIterPerWarp>,
+                                   number<MPackIterPerWarp>,
+                                   number<NPackIterPerWarp>,
+                                   number<KXdlPack>,
+                                   number<MXdlPack>,
+                                   number<NXdlPack>>{}( //
+                    [&](auto ikpack, auto impack, auto inpack, auto ikxdl, auto imxdl, auto inxdl) {
+                        constexpr auto m_iter    = impack * MXdlPack + imxdl;
+                        constexpr auto n_iter    = inpack * NXdlPack + inxdl;
+                        constexpr auto k_iter    = ikpack * KXdlPack + ikxdl;
+                        constexpr auto APackIter = ikxdl * MXdlPack + imxdl; // idx inside a xdl pack
+                        // warp GEMM
+                        WG{}.template operator()(
+                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
+                            bit_cast<typename WG::AVecType>(a_warp_tensor(number<APackIter>{})),
+                            bit_cast<typename WG::BVecType>(
+                                b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
+                            scale_a_tile_tensor_pong(impack)(ikpack).get_thread_buffer()[0], // scale A
+                            scale_b_tile_tensor_pong(inpack)(ikpack).get_thread_buffer()[0]); // scale B
+                        // preload next A from lds
+                        constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
+                        if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
+                                     (n_iter == NIterPerWarp - 1))
+                        {
+                            constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                            constexpr auto AkIter = addr / 2 % 2;
+                            a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
+                                a_warp_window_pong,
+                                tuple<number<AmIter>,
+                                      number<AkIter>>{});
+                        }
                     });
-                });
 
                 // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
                 s_waitcnt< // vmcnt
                     Bload_num + ScaleAload_num + ScaleBload_num>();
@@ -866,7 +849,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                     a_warp_tensor(number<m_iter>{}) = load_tile_with_offset(
-                        a_warp_window_ping, tuple<number<AmIter>, number<AkIter>>{});
+                        a_warp_window_ping,
+                        tuple<number<AmIter>, number<AkIter>>{});
                 });
                 HotLoopScheduler();
             };
@@ -904,62 +888,54 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-                static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                    static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                        scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
+                    static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                        scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
                             scale_a_dram_window,
-                            mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
+                            impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
                     });
                 });
 
-                static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                    static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                        scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
+                static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
+                    static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+                        scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
                             scale_b_dram_window,
-                            nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
+                            inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
                     });
                 });
 
                 // GEMM loopK-1
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
-                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
-                                    constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
-                                    constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
-                                    constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
-                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                        constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-                                        // warp GEMM
-                                        WG{}.template
-                                        operator()(
-                                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
-                                            bit_cast<typename WG::AVecType>(
-                                                a_warp_tensor(number<AwarpIter>{})),
-                                            bit_cast<typename WG::BVecType>(
-                                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
-                                            scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0], // scale A
-                                            scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0]); // scale B
-                                    });
-                                    // preload next A from lds
-                                    constexpr auto addr =
-                                        m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
-                                    if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                                 (nIter_pack == NPackIterPerWarp - 1))
-                                    {
-                                        constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                        constexpr auto AkIter = addr / 2 % 2;
-                                        a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
-                                            a_warp_window_ping,
-                                            tuple<number<AmIter>, number<AkIter>>{});
-                                    }
-                                });
-                            });
-                        });
+                static_for_product<number<KPackIterPerWarp>,
+                                   number<MPackIterPerWarp>,
+                                   number<NPackIterPerWarp>,
+                                   number<KXdlPack>,
+                                   number<MXdlPack>,
+                                   number<NXdlPack>>{}( //
+                    [&](auto ikpack, auto impack, auto inpack, auto ikxdl, auto imxdl, auto inxdl) {
+                        constexpr auto m_iter    = impack * MXdlPack + imxdl;
+                        constexpr auto n_iter    = inpack * NXdlPack + inxdl;
+                        constexpr auto k_iter    = ikpack * KXdlPack + ikxdl;
+                        constexpr auto APackIter = ikxdl * MXdlPack + imxdl; // idx inside a xdl pack
+                        // warp GEMM
+                        WG{}.template operator()(
+                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
+                            bit_cast<typename WG::AVecType>(a_warp_tensor(number<APackIter>{})),
+                            bit_cast<typename WG::BVecType>(
+                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
+                            scale_a_tile_tensor_ping(impack)(ikpack).get_thread_buffer()[0], // scale A
+                            scale_b_tile_tensor_ping(inpack)(ikpack).get_thread_buffer()[0]); // scale B
+                        // preload next A from lds
+                        constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
+                        if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
+                                     (n_iter == NIterPerWarp - 1))
+                        {
+                            constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                            constexpr auto AkIter = addr / 2 % 2;
+                            a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
+                                a_warp_window_ping,
+                                tuple<number<AmIter>,
+                                      number<AkIter>>{});
+                        }
                     });
-                });
 
                 // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
                 s_waitcnt< // vmcnt
                     Bload_num + ScaleAload_num + ScaleBload_num>();
@@ -970,97 +946,82 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
                     a_warp_tensor(number<m_iter>{}) = load_tile_with_offset(
-                        a_warp_window_pong, tuple<number<AmIter>, number<AkIter>>{});
+                        a_warp_window_pong,
+                        tuple<number<AmIter>, number<AkIter>>{});
                 });
 
                 Last2ndHotLoopScheduler();
 
                 // GEMM loopK
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
-                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
-                                    constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
-                                    constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
-                                    constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
-                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                        constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-                                        // warp GEMM
-                                        WG{}.template
-                                        operator()(
-                                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
-                                            bit_cast<typename WG::AVecType>(
-                                                a_warp_tensor(number<AwarpIter>{})),
-                                            bit_cast<typename WG::BVecType>(
-                                                b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
-                                            scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0], // scale A
-                                            scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0]); // scale B
-                                    });
-                                    // preload next A from lds
-                                    constexpr auto addr =
-                                        m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
-                                    if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                                 (nIter_pack == NPackIterPerWarp - 1))
-                                    {
-                                        constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                        constexpr auto AkIter = addr / 2 % 2;
-                                        a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
-                                            a_warp_window_pong,
-                                            tuple<number<AmIter>, number<AkIter>>{});
-                                    }
-                                });
-                            });
-                        });
+                static_for_product<number<KPackIterPerWarp>,
+                                   number<MPackIterPerWarp>,
+                                   number<NPackIterPerWarp>,
+                                   number<KXdlPack>,
+                                   number<MXdlPack>,
+                                   number<NXdlPack>>{}( //
+                    [&](auto ikpack, auto impack, auto inpack, auto ikxdl, auto imxdl, auto inxdl) {
+                        constexpr auto m_iter    = impack * MXdlPack + imxdl;
+                        constexpr auto n_iter    = 
inpack * NXdlPack + inxdl;
+                        constexpr auto k_iter    = ikpack * KXdlPack + ikxdl;
+                        constexpr auto APackIter = ikxdl * MXdlPack + imxdl; // idx inside a xdl pack
+                        // warp GEMM
+                        WG{}.template operator()(
+                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
+                            bit_cast<typename WG::AVecType>(a_warp_tensor(number<APackIter>{})),
+                            bit_cast<typename WG::BVecType>(
+                                b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{})),
+                            scale_a_tile_tensor_pong(impack)(ikpack).get_thread_buffer()[0], // scale A
+                            scale_b_tile_tensor_pong(inpack)(ikpack).get_thread_buffer()[0]); // scale B
+                        // preload next A from lds
+                        constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
+                        if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
+                                     (n_iter == NIterPerWarp - 1))
+                        {
+                            constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                            constexpr auto AkIter = addr / 2 % 2;
+                            a_warp_tensor(number<APackIter>{}) =
+                                load_tile_with_offset(a_warp_window_pong,
+                                                      tuple<number<AmIter>,
+                                                            number<AkIter>>{});
+                        }
                     });
-                });
 
                 LastHotLoopScheduler();
             }
             else if constexpr(TailNum == TailNumber::Odd)
             {
                 // GEMM loopK
-                static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
-                    static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
-                        static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
-                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
-                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
-                                    constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
-                                    constexpr auto m_iter    = mIter_pack * MXdlPack + imxdl;
-                                    constexpr auto k_iter    = kIter_pack * KXdlPack + ikxdl;
-                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
-                                        constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
-                                        // warp GEMM
-                                        WG{}.template
-                                        operator()(
-                                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
-                                            bit_cast<typename WG::AVecType>(
-                                                a_warp_tensor(number<AwarpIter>{})),
-                                            bit_cast<typename WG::BVecType>(
-                                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
-                                            scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0], // scale A
-                                            scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
-                                                .get_thread_buffer()[0]); // scale B
-                                    });
-                                    // preload next A from lds
-                                    constexpr auto addr =
-                                        m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
-                                    if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
-                                                 (nIter_pack == NPackIterPerWarp - 1))
-                                    {
-                                        constexpr auto AmIter = addr % 2 + addr / 4 * 2;
-                                        constexpr auto AkIter = addr / 2 % 2;
-                                        a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
-                                            a_warp_window_ping,
-                                            tuple<number<AmIter>, number<AkIter>>{});
-                                    }
-                                });
-                            });
-                        });
+                static_for_product<number<KPackIterPerWarp>,
+                                   number<MPackIterPerWarp>,
+                                   number<NPackIterPerWarp>,
+                                   number<KXdlPack>,
+                                   number<MXdlPack>,
+                                   number<NXdlPack>>{}( //
+                    [&](auto ikpack, auto impack, auto inpack, auto ikxdl, auto imxdl, auto inxdl) {
+                        constexpr auto m_iter    = impack * MXdlPack + imxdl;
+                        constexpr auto n_iter    = inpack * NXdlPack + inxdl;
+                        constexpr auto k_iter    = ikpack * KXdlPack + ikxdl;
+                        constexpr auto APackIter = ikxdl * MXdlPack + imxdl; // idx inside a xdl pack
+                        // warp GEMM
+                        WG{}.template operator()(
+                            c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
+                            bit_cast<typename WG::AVecType>(a_warp_tensor(number<APackIter>{})),
+                            bit_cast<typename WG::BVecType>(
+                                b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{})),
+                            scale_a_tile_tensor_ping(impack)(ikpack).get_thread_buffer()[0], // scale A
+                            scale_b_tile_tensor_ping(inpack)(ikpack).get_thread_buffer()[0]); // scale B
+                        // preload next A from lds
+                        constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
+                        if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
+                                     (n_iter == NIterPerWarp - 1))
+                        {
+                            constexpr auto AmIter = addr % 2 + addr / 4 * 2;
+                            constexpr auto AkIter = addr / 2 % 2;
+                            a_warp_tensor(number<APackIter>{}) =
+                                load_tile_with_offset(a_warp_window_ping,
+                                                      tuple<number<AmIter>,
+                                                            number<AkIter>>{});
+                        }
                     });
-                });
 
                 LastHotLoopScheduler();
             }
             else
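Note: the TailNumber dispatch exists because the hot loop retires two K-blocks (one ping, one pong) per iteration, so the epilogue depends on the parity of the total block count: Even drains a final ping+pong pair, Odd a single ping block. A toy model of that bookkeeping (the `(num_loop - 1) / 2` iteration count is an assumption of this sketch, not taken from the hunk above):

    #include <cstdio>

    enum class TailNumber { Odd, Even };

    int main()
    {
        for(int num_loop = 1; num_loop <= 5; ++num_loop)
        {
            const int hot_iters   = (num_loop - 1) / 2; // full ping+pong pairs in the hot loop
            const TailNumber tail = (num_loop % 2) ? TailNumber::Odd : TailNumber::Even;
            // Even drains a final ping+pong pair, Odd a single ping block.
            const int drained = hot_iters * 2 + (tail == TailNumber::Even ? 2 : 1);
            std::printf("num_loop=%d hot_iters=%d tail=%s drained=%d\n",
                        num_loop,
                        hot_iters,
                        tail == TailNumber::Odd ? "Odd" : "Even",
                        drained); // drained always equals num_loop
        }
    }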
diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
index e188ddec61..081fbbe48a 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
@@ -75,19 +75,41 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
         return BlockFlatmmASmemBSmemCRegV1{};
     }
 
-    template <typename TensorView>
-    CK_TILE_DEVICE static constexpr auto
-    MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution()
     {
-        const auto& naive_desc = naive_view.get_tensor_descriptor();
-        constexpr auto ndims   = remove_cvref_t<TensorView>::get_num_of_dimension();
-        static_assert(ndims == 2, "only support 2D tensor");
-        const auto rows = naive_desc.get_length(number<0>{});
-        const auto cols = naive_desc.get_length(number<1>{});
+        constexpr index_t K2 = DWORDx4;                             // 16 bytes
+        constexpr index_t K1 = kDramLoadPackBytes / K2;             // 8
+        constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
 
-        constexpr index_t K2 = AK1; // f4=32; f8=16
+        constexpr index_t M2 = WaveSize / K1;        // 8
+        constexpr index_t M1 = BlockSize / WaveSize; // 4
+        constexpr index_t M0 = MPerBlock / (M2 * M1);
+        static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
+        static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
+                      "K0, K1, K2 must cover whole KPerBlock!");
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding< //
+                sequence<1>,
+                tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
+                tuple<sequence<1>, sequence<1, 2>>,                // M1 M2,K1
+                tuple<sequence<1>, sequence<2, 1>>,
+                sequence<1, 2, 2>, // M0,K0,K2
+                sequence<0, 0, 2>>{});
+    }
+
+    template <typename Problem, typename WindowTmp>
+    CK_TILE_DEVICE static constexpr auto
+    MakeMX_AAsyncLoadBytesDramWindow(const WindowTmp& window_tmp)
+    {
+        constexpr auto ndims = std::decay_t<WindowTmp>::get_num_of_dimension();
+        static_assert(ndims == 2, "only support 2D tensor");
+        auto&& tensor_view_tmp  = window_tmp.get_bottom_tensor_view();
+        const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
+
+        constexpr index_t K2 = DWORDx4;                      // 16 bytes
         constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
-        const index_t K0     = cols / (K1 * K2);
+        const index_t K0     = cols / (K1 * K2 * APackedSize);
         const auto col_lens  = make_tuple(K0, number<K1>{}, number<K2>{});
 
         constexpr index_t M1 = 4; // so that we can use imm offset to load lds
@@ -110,41 +132,24 @@ struct MXFlatmmPipelineAgBgCrPolicy
                 make_merge_transform_v3_division_mod(col_lens)),
             make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
-        // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
 
-        return tensor_view<decltype(naive_view.buf_), decltype(desc),
-                           TensorView::DstInMemOp>{naive_view.buf_, desc};
+        auto&& byte_ptr = reinterpret_cast<const int8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
+        auto&& byte_tensor_view = make_tensor_view<address_space_enum::global>(byte_ptr, desc);
+
+        auto&& origin_tmp = window_tmp.get_window_origin();
+        return make_tile_window(byte_tensor_view,
+                                make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
+                                {origin_tmp[0], origin_tmp[1] / APackedSize},
+                                MakeMX_ABytesDramTileDistribution<Problem>());
     }
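Note: the point of the byte window is that with a sub-byte A type, APackedSize elements share one byte, so the K origin and K extent are divided by APackedSize while the M axis stays in elements. A host-side sanity check of that conversion (all tile sizes here are assumptions for illustration, not values taken from the policy):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        // Assumed fp4-style tile: APackedSize elements share one byte.
        constexpr int APackedSize        = 2;
        constexpr int KPerBlock          = 512; // elements
        constexpr int DWORDx4            = 16;  // bytes per dwordx4 vector load
        constexpr int kDramLoadPackBytes = 128; // assumed bytes per packed DRAM load

        constexpr int K2 = DWORDx4;                             // 16 bytes
        constexpr int K1 = kDramLoadPackBytes / K2;             // 8
        constexpr int K0 = KPerBlock / (K1 * K2 * APackedSize); // 2

        static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
                      "K0, K1, K2 must cover whole KPerBlock");

        // Element-indexed window origin -> byte-indexed origin (K axis only;
        // M stays in elements, mirroring {origin[0], origin[1] / APackedSize}).
        const int origin_elem[2] = {64, 256};
        const int origin_byte[2] = {origin_elem[0], origin_elem[1] / APackedSize};
        assert(origin_byte[1] * APackedSize == origin_elem[1]); // origins are vector-aligned
        std::printf("K0=%d, byte origin = {%d, %d}\n", K0, origin_byte[0], origin_byte[1]);
    }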
-    CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
-    {
-        constexpr index_t K2 = AK1;                                   // f4=32; f8=16
-        constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
-        constexpr index_t K0 = KPerBlock / (K1 * K2);                 // KPerBlock/256
-
-        constexpr index_t M2 = WaveSize / K1;        // 8
-        constexpr index_t M1 = BlockSize / WaveSize; // 4
-        constexpr index_t M0 = MPerBlock / (M2 * M1);
-        static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
-        static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
-
-        return make_static_tile_distribution(
-            tile_distribution_encoding< //
-                sequence<1>,
-                tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
-                tuple<sequence<1>, sequence<1, 2>>,                // M1 M2,K1
-                tuple<sequence<1>, sequence<2, 1>>,
-                sequence<1, 2, 2>, // M0,K0,K2
-                sequence<0, 0, 2>>{});
-    }
-
-    CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
-    {
-        constexpr index_t K2 = AK1;                          // f4=32; f8=16
+    CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor()
+    {
+        constexpr index_t K2 = AK1 / APackedSize;            // 16
         constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
-        constexpr index_t K0 = KPerBlock / (K1 * K2);        // KPerBlock/256
-        static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
+        constexpr index_t K0 = KPerBlock / (K1 * AK1);       // KPerBlock/256
+        static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
+                      "K0, K1, K2 must cover whole KPerBlock!");
 
         constexpr index_t M3 = 4; // so that we can use imm offset to load lds
         constexpr index_t M2 = WaveSize / K1 / M3;         // 2
         constexpr index_t M1 = BlockSize / WaveSize;       // 4
         constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
         static_assert(M0 * M1 * M2 * M3 == MPerBlock,
                       "M0, M1, M2, M3 must cover whole MPerBlock!");
 
-        constexpr index_t Pad = 4 * K2; // 4 * 32
+        constexpr index_t Pad = 4 * K2; // 4 dwords
 
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
             make_tuple(number<M0>{},
@@ -205,7 +210,7 @@ struct MXFlatmmPipelineAgBgCrPolicy
         return a_lds_block_desc;
     }
 
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution()
     {
         static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
 
@@ -213,20 +218,21 @@ struct MXFlatmmPipelineAgBgCrPolicy
             return make_static_tile_distribution(
                 tile_distribution_encoding< //
                     sequence,
-                    tuple, sequence>,
+                    tuple, sequence>,
                    tuple, sequence<2, 1>>,
                     tuple, sequence<0, 2>>,
                     sequence<2>,
                     sequence<1>>{});
         else
-            return make_static_tile_distribution(tile_distribution_encoding< //
-                                                     sequence,
-                                                     tuple,
-                                                     sequence>,
-                                                     tuple, sequence<2, 1>>,
-                                                     tuple, sequence<1, 2>>,
-                                                     sequence<2, 2>,
-                                                     sequence<0, 2>>{});
+            return make_static_tile_distribution(
+                tile_distribution_encoding< //
+                    sequence,
+                    tuple,
+                    sequence>,
+                    tuple, sequence<2, 1>>,
+                    tuple, sequence<1, 2>>,
+                    sequence<2, 2>,
+                    sequence<0, 2>>{});
     }
 
     CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
@@ -364,8 +370,7 @@ struct MXFlatmmPipelineAgBgCrPolicy
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
-        return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
-               APackedSize;
+        return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor<Problem>().get_element_space_size();
     }
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); }
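Note: GetSmemSizeA drops the trailing `/ APackedSize` because the LDS descriptor above is now denominated in bytes rather than packed elements, so the two formulas must yield the same byte count. A quick equivalence check under simplified assumptions (no LDS row padding, 1-byte ADataType; the real descriptor adds padding, so this only illustrates the unit change):

    #include <cstdio>

    int main()
    {
        constexpr int MPerBlock        = 128;
        constexpr int KPerBlock        = 512; // elements
        constexpr int APackedSize      = 2;   // fp4: two elements per byte
        constexpr int sizeof_ADataType = 1;

        // old: descriptor counts elements, then divides by APackedSize
        constexpr int old_bytes =
            sizeof_ADataType * (MPerBlock * KPerBlock) / APackedSize;
        // new: descriptor already counts bytes, no division needed
        constexpr int new_bytes =
            sizeof_ADataType * (MPerBlock * (KPerBlock / APackedSize));

        static_assert(old_bytes == new_bytes, "byte-denominated LDS size must match");
        std::printf("smem A bytes = %d\n", new_bytes);
    }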
@@ -383,10 +388,10 @@ struct MXFlatmmPipelineAgBgCrPolicy
     }
 
     FORWARD_METHOD_(GetBlockFlatmm);
-    FORWARD_METHOD_(MakeMX_AAsyncLoadDramDescriptor);
-    FORWARD_METHOD_(MakeMX_ADramTileDistribution);
-    FORWARD_METHOD_(MakeMX_ALdsBlockDescriptor);
-    FORWARD_METHOD_(MakeMX_ALDS_TileDistribution);
+    FORWARD_METHOD_(MakeMX_AAsyncLoadBytesDramWindow);
+    FORWARD_METHOD_(MakeMX_ABytesDramTileDistribution);
+    FORWARD_METHOD_(MakeMX_ALdsBytesBlockDescriptor);
+    FORWARD_METHOD_(MakeMX_ALDSBytes_TileDistribution);
     FORWARD_METHOD_(MakeMX_BFlatBytesDramTileDistribution);
     FORWARD_METHOD_(MakeMX_BFlatBytesDramWindow);
     FORWARD_METHOD_(MakeMX_ScaleA_DramTileDistribution);