diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f813610890..9a4ec64242 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -304,6 +304,14 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +template +struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template struct GemmTypeConfig; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 2b8f8b32ae..0f323cb0e3 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -91,7 +91,11 @@ int main(int argc, char* argv[]) try { +#if CK_TILE_USE_WMMA + return !run_gemm_example(arg_parser); +#else return !run_gemm_example(arg_parser); +#endif } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index e926c3cbaa..cc980a75f7 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -176,16 +176,43 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index f8e21d5ee4..1fb53909ac 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -190,6 +190,30 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr bool kPadK = true; }; +template +struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 32 / sizeof(PrecType); + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadK = true; + + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + template struct PipelineTypeTraits; @@ -266,16 +290,43 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template (argc, argv); +#else return !run_grouped_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 93117e5b75..280da8d333 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -268,6 +268,9 @@ int main(int argc, char* argv[]) try { +#if defined(CK_TILE_USE_WMMA) + return !run_flatmm_example(argc, argv); +#else int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) { @@ -285,6 +288,7 @@ int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } +#endif } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 64e141860e..8f8f65e214 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -86,6 +86,14 @@ struct FlatmmConfig16_950 : public FlatmmConfig16 static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128; }; +template +struct FlatmmConfig16_Wmma : public FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template struct GemmBasicTypeConfig; @@ -183,8 +191,10 @@ auto create_args(int argc, char* argv[]) .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") +#if !defined(CK_TILE_USE_WMMA) .insert( "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)") +#endif .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "flatmm_basic.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b6b92b5801..63d0a80555 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -43,15 +43,40 @@ auto shuffle_b(const ck_tile::HostTensor& t) int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2) - : (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4); - ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, - FlatmmConfig::N_Warp_Tile, - k_ / FlatmmConfig::K_Warp_Tile, - divisor, - FlatmmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + if(ck_tile::is_gfx12_supported()) + { + // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase + constexpr int divisor = 2; + constexpr int kABK0PerLane = 2; + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + kABK0PerLane, + FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + FlatmmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 20ca976590..a924279d52 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -127,7 +127,10 @@ struct FlatmmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1a28366e24..0cae1a467d 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -185,11 +185,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV } template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 3ca79fc46e..5fd1fb8d39 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -237,8 +237,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; +#if defined(__gfx11__) + constexpr index_t scale = 4; +#else constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; +#endif if constexpr(TileShape::WarpTile::at(I1) == 32) { return TileShape::WarpTile::at(I2) * scale / 2; @@ -342,7 +346,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -350,8 +354,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); - constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); @@ -362,16 +371,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t NRepeat = 1; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? + sequence, // ? tuple, // second direction sequence>, // first direction // wave in blk, // thd in wave // // - tuple, sequence<1, 2>>, // which direction - tuple, sequence<2, 2>>, // which index + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index // sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 8b95639516..71ca907c07 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -89,14 +89,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { using TileShape = typename Problem::BlockGemmShape; +#if defined(__gfx11__) + constexpr index_t scale = 4; +#else + constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; +#endif if constexpr(TileShape::WarpTile::at(I1) == 32) { - return TileShape::WarpTile::at(I2) / 2; + return TileShape::WarpTile::at(I2) * scale / 2; } else { static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) / 4; + return TileShape::WarpTile::at(I2) * scale / 4; } } @@ -192,7 +197,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -200,8 +205,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); - constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = 1; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); @@ -212,16 +222,15 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t NRepeat = 1; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; - return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? + sequence, // ? tuple, // second direction sequence>, // first direction // wave in blk, // thd in wave // // - tuple, sequence<1, 2>>, // which direction - tuple, sequence<2, 2>>, // which index + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index // sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index b91c211d91..290f24a7f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -189,11 +189,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 } template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 7104e318d2..129eac6557 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -146,10 +146,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t mfma_per_wg = 1; #endif static constexpr index_t dsread_per_wg = - WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize; - static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0); - - static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); +#if defined(__HIP_DEVICE_COMPILE__) + static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + Problem::VectorLoadSize == + 0); +#endif + static constexpr index_t dsread_num_perK = + WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize; static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; static constexpr index_t Aload_num_perK = dswrite_num_perK; @@ -499,12 +503,12 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction> - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const { static_assert( std::is_same_v>,