diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 4dc8d049bb..18980ee0f4 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -33,6 +33,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c ck_tile::sequence>; + std::cout << "CodegenFlatmmShape: " << CodegenFlatmmShape::GetName() << std::endl; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner>( argc, argv, Row{}, Col{}, Row{}); } @@ -261,11 +264,13 @@ int main(int argc, char* argv[]) { int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) - { + { + std::cout << "Running with warp tile size 16x16" << std::endl; return !run_flatmm_example(argc, argv); } else if(warp_tile == 1) { + std::cout << "Running with 32x32 tile size" << std::endl; return !run_flatmm_example(argc, argv); } else if(warp_tile == 2) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 963a6ba675..1fd58989f2 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -77,7 +77,7 @@ struct FlatmmConfig16 static constexpr int TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = true; }; template diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 02130ac3de..db647915b8 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -618,7 +618,7 @@ struct FlatmmKernel a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); // Create empty D tensors constexpr auto empty_ds_dram_windows = ck_tile::make_tuple(); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 837aeb13e3..6ae5493f34 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -339,7 +339,37 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } } } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() + { + using ADataType = remove_cvref_t; + //using ALayout = remove_cvref_t; + constexpr index_t BlockSize = Problem::kBlockSize; + + // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + // constexpr index_t M0 = MPerBlock / (M2 * M1); + // static_assert(M0 * M1 * M2 == MPerBlock, + // "Incorrect M0, M2, M1 configuration! " + // "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2>, + sequence<1>>{}); + } template CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() {