diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 87237458c5..8249c3a8de 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -10,6 +10,6 @@ endif() # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1") target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 8a94d749f3..7de8d42e58 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -33,17 +33,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c ck_tile::sequence>; - - // static constexpr index_t kM = BlockTile::at(number<0>{}); - // static constexpr index_t kN = BlockTile::at(number<1>{}); - // static constexpr index_t kK = BlockTile::at(number<2>{}); - - // static constexpr bool PermuteA = PermuteA_; - // static constexpr bool PermuteB = PermuteB_; - - // static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{}); // 4 - // static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});// 16 * 64 - // static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{}); // 16 * 128 using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner); const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index eb80497436..4cd6b40b7f 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -71,10 +71,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t GetVectorSizeA() { + static_assert(PipelinePolicy::template GetVectorSizeA()==16); return PipelinePolicy::template GetVectorSizeA(); } static constexpr index_t GetVectorSizeB() { + static_assert(PipelinePolicy::template GetVectorSizeB()==16); return PipelinePolicy::template GetVectorSizeB(); } @@ -706,19 +708,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - // Prefill A0 - // if constexpr(std::is_same_v) - // { - // auto a_shuffle_tmp = make_static_distributed_tensor( - // PipelinePolicy::template MakeShuffledARegBlockDistribution()); - // shuffle_tile(a_shuffle_tmp, a_block_tile); - // const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - // store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - // } - // else - // { - // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); - // } auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); __builtin_amdgcn_sched_barrier(0); @@ -760,17 +749,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // { while(iCounter > 0) { - // prefetch B(2i+1) - // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - // move_tile_window(b_flat_dram_windows(nIter)(kIter), - // {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - // b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - // }); - // }); b_warp_tensor_pong = load_tile(b_flat_dram_window); @@ -805,15 +783,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // if constexpr(mIter==0 && nIter ==0) - // if(threadIdx.x % 16== 0 && threadIdx.x<64){ - // for(int i=0;i(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i))); - // } - // } // warp GEMM WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_ping(nIter)(kIter)); // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -840,22 +811,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - // HotLoopScheduler(); + HotLoopScheduler(); //Next K - - // prefetch B(2i+2) - // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - // move_tile_window(b_flat_dram_windows(nIter)(kIter), - // {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - // b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - // }); - // }); - b_warp_tensor_ping = load_tile(b_flat_dram_window); // Prefill A(2i+2) @@ -883,23 +841,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1>{}, a_warp_y_lengths)); - // warp GEMM - // if(threadIdx.x % 16 == 0 && threadIdx.x>=192){ - // for(int i=0;i(a_warp_tensor.thread_buf_(i)), type_convert(b_warp_tensor_ping(nIter)(kIter).thread_buf_(i))); - // } - - // for(int i=0;i(a_warp_tensor_ping(AwarpIter).thread_buf_[i])); - // } - // } - // if constexpr(mIter==0 && nIter ==0) - // if(threadIdx.x % 16== 0 && threadIdx.x<64){ - // for(int i=0;i(b_warp_tensor_pong(nIter)(kIter).thread_buf_(i))); - - // } - // } BWarpTensor b_warp_tensor; b_warp_tensor.get_thread_buffer() = b_warp_tensor_pong.get_y_sliced_thread_data( @@ -932,7 +873,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - // HotLoopScheduler(); + HotLoopScheduler(); iCounter--; } @@ -1042,45 +983,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // // __builtin_amdgcn_sched_barrier(0); // } // else if constexpr(TailNum == TailNumber::Odd) - if constexpr(TailNum == TailNumber::Odd) - { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = number{}; - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + // if constexpr(TailNum == TailNumber::Odd) + // { + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // constexpr auto AwarpIter = number{}; + // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // // read C warp tensor from C block tensor + // CWarpTensor c_warp_tensor; - AWarpTensor a_warp_tensor; + // AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1>{}, a_warp_y_lengths)); + // a_warp_tensor.get_thread_buffer() = a_warp_tensor_ping(AwarpIter).get_y_sliced_thread_data( + // merge_sequences(sequence{}, a_warp_y_index_zeros), + // merge_sequences(sequence<1>{}, a_warp_y_lengths)); - // set_tile(a_warp_tensor, 1.0f); - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // // set_tile(a_warp_tensor, 1.0f); + // c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + // merge_sequences(sequence{}, c_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + // BWarpTensor b_warp_tensor; + // b_warp_tensor.get_thread_buffer() = b_warp_tensor_ping.get_y_sliced_thread_data( + // merge_sequences(sequence{}, b_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - // __builtin_amdgcn_sched_barrier(0x7F6); - }); - }); - if constexpr(mIter < MIterPerWarp - 2) - { - a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number{})); - } - }); - } + // // write C warp tensor into C block tensor + // c_block_tile.set_y_sliced_thread_data( + // merge_sequences(sequence{}, c_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + // c_warp_tensor.get_thread_buffer()); + // // __builtin_amdgcn_sched_barrier(0x7F6); + // }); + // }); + // if constexpr(mIter < MIterPerWarp - 2) + // { + // a_warp_tensor_ping(AwarpIter) = load_tile(a_warp_windows_ping(number{})); + // } + // }); + // } // } return c_block_tile; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index cba7bd780d..964fdbad4c 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -23,38 +23,38 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - // if constexpr(MPerXdl == 16 && NPerXdl == 16) - // { - // /*reduce transform layers,compare with old ck*/ - // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - // constexpr index_t KPack = GetSmemPackA(); + if constexpr(MPerXdl == 16 && NPerXdl == 16) + { + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); - // constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - // make_tuple(number{}, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - // constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - // a_lds_block_desc_0, - // make_tuple(make_xor_transform( - // make_tuple(number{}, number{})), - // make_pass_through_transform(number{})), - // make_tuple(sequence<1, 0>{}, sequence<2>{}), - // make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - // constexpr auto a_lds_block_desc = transform_tensor_descriptor( - // a_lds_block_desc_permuted, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // return a_lds_block_desc; - // } - // else + return a_lds_block_desc; + } + else { constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;