diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 8acfea4580..25734a23ce 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -283,24 +283,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeADramTileDistribution()); + return Policy::template MakeAAsyncLoadBytesDramWindow( + a_dram_block_window_tmp[number{}]); }, number{}); - // B DRAM window(s) for load + // B DRAM tile window(s) for async byte-based load auto b_tile_windows = generate_tuple( [&](auto idx) { - return make_tile_window( - b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp[number{}].get_window_origin(), - Policy::template MakeBDramTileDistribution()); + return Policy::template MakeBAsyncLoadBytesDramWindow( + b_dram_block_window_tmp[number{}]); }, number{}); @@ -334,21 +328,24 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, number{}); + const auto b_dram_tile_window_step = + make_tuple(number<0>{}, number{}); - constexpr ADramTileWindowStep a_dram_tile_window_step = - is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = - is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + // Define async load tile lambda + auto async_load_tile_ = [](auto lds, auto dram) { + async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{}); + }; // read A(0), B(0) from DRAM to LDS window(0) // and advance the DRAM windows - Base::GlobalPrefetchAsync( - a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]); + move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step); + async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]); + move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step); // initialize block gemm auto block_gemm = BlockGemm(); @@ -359,10 +356,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step); + async_load_tile_(a_copy_lds_window1, a_tile_windows[number<0>{}]); + move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step); + async_load_tile_(b_copy_lds_window1, b_tile_windows[number<0>{}]); + move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step); // tile distribution for the register tiles constexpr auto ALdsTileDistr = @@ -423,10 +420,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], a_dram_tile_window_step); - Base::GlobalPrefetchAsync( - b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]); + move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step); + async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]); + move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step); if(HasHotLoop) { @@ -445,12 +442,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window1, - b_tile_windows[number<0>{}], - b_dram_tile_window_step); + async_load_tile_(a_copy_lds_window1, a_tile_windows[number<0>{}]); + move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step); + async_load_tile_(b_copy_lds_window1, b_tile_windows[number<0>{}]); + move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-3) = A(i-3) @ B(i-3) block_gemm(c_block_tile, a_block_tile0, b_block_tile0); HotLoopScheduler(); @@ -466,12 +461,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}], - a_dram_tile_window_step); - Base::GlobalPrefetchAsync(b_copy_lds_window0, - b_tile_windows[number<0>{}], - b_dram_tile_window_step); + async_load_tile_(a_copy_lds_window0, a_tile_windows[number<0>{}]); + move_tile_window(a_tile_windows[number<0>{}], a_dram_tile_window_step); + async_load_tile_(b_copy_lds_window0, b_tile_windows[number<0>{}]); + move_tile_window(b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-2) = A(i-2) @ B(i-2) block_gemm(c_block_tile, a_block_tile1, b_block_tile1); HotLoopScheduler(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index ffe889af41..c0c9edbc3d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -18,6 +18,9 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; + static constexpr index_t kDramLoadPackBytes = 128; + static constexpr index_t DWORDx4 = 16; + template > CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -93,6 +96,141 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy } } + // Methods for async byte-based loading (similar to mx_flatmm) + template + CK_TILE_DEVICE static constexpr auto MakeAAsyncLoadBytesDramWindow(const WindowTmp& window_tmp) + { + using ADataType = remove_cvref_t; + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto ndims = std::decay_t::get_num_of_dimension(); + static_assert(ndims == 2, "only support 2D tensor"); + auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); + const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + + constexpr index_t K2 = DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; + const index_t K0 = cols / (K1 * K2 * APackedSize); + const auto col_lens = make_tuple(K0, number{}, number{}); + + constexpr index_t M1 = 4; + const index_t M0 = integer_divide_ceil(rows, M1); + const auto row_lens = make_tuple(M0, number{}); + + const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto desc_0 = decltype(d0)( + d0.get_transforms(), tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); + const auto desc_1 = transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(M0), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(K0), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); + const auto desc = + transform_tensor_descriptor(desc_1, + make_tuple(make_merge_transform_v3_division_mod(row_lens), + make_merge_transform_v3_division_mod(col_lens)), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); + + auto&& origin_tmp = window_tmp.get_window_origin(); + + // Create tile distribution inline (reuse K2, K1, K0 from above) + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t M2_dstr = WaveSize / K1; + constexpr index_t M1_dstr = BlockSize / WaveSize; + constexpr index_t M0_dstr = MPerBlock / (M2_dstr * M1_dstr); + + constexpr auto tile_dstr = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + + return make_tile_window(byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / APackedSize}, + tile_dstr); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBAsyncLoadBytesDramWindow(const WindowTmp& window_tmp) + { + using BDataType = remove_cvref_t; + constexpr index_t BPackedSize = numeric_traits::PackedSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto ndims = std::decay_t::get_num_of_dimension(); + static_assert(ndims == 2, "only support 2D tensor"); + auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); + const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + + constexpr index_t K2 = DWORDx4; + constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; + const index_t K0 = cols / (K1 * K2 * BPackedSize); + const auto col_lens = make_tuple(K0, number{}, number{}); + + constexpr index_t N1 = 4; + const index_t N0 = integer_divide_ceil(rows, N1); + const auto row_lens = make_tuple(N0, number{}); + + const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto desc_0 = decltype(d0)( + d0.get_transforms(), tensor_view_tmp.get_tensor_descriptor().get_element_space_size()); + const auto desc_1 = transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(N0), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(K0), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); + const auto desc = + transform_tensor_descriptor(desc_1, + make_tuple(make_merge_transform_v3_division_mod(row_lens), + make_merge_transform_v3_division_mod(col_lens)), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + auto&& byte_tensor_view = make_tensor_view(byte_ptr, desc); + + auto&& origin_tmp = window_tmp.get_window_origin(); + + // Create tile distribution inline (reuse K2, K1, K0 from above) + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t N2_dstr = WaveSize / K1; + constexpr index_t N1_dstr = BlockSize / WaveSize; + constexpr index_t N0_dstr = NPerBlock / (N2_dstr * N1_dstr); + + constexpr auto tile_dstr = make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + + return make_tile_window(byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / BPackedSize}, + tile_dstr); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() {