diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 5c2307ca7d..5c5460daf3 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -78,9 +78,6 @@ struct buffer_store; template struct buffer_store_if; -template -struct async_buffer_load; - #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type @@ -1651,90 +1648,6 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, } } -template -struct async_buffer_load<16, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem, - int32x4_t rsrc /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { -#if defined(__gfx950__) - asm volatile("s_nop 4\n" - "buffer_load_dwordx4 %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(v_offset), "s"(rsrc), "n"(i_offset) - : "memory"); - // __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 16, i_offset, v_offset, 0, 0, 0); -#else - ignore = smem; - ignore = rsrc; - ignore = v_offset; - ignore = i_offset; -#endif - } -}; - -template -struct async_buffer_load<12, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem, - int32x4_t rsrc /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { -#if defined(__gfx950__) - asm volatile("s_nop 4\n" - "buffer_load_dwordx3 %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(v_offset), "s"(rsrc), "n"(i_offset) - : "memory"); - // __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 12, i_offset, v_offset, 0, 0, 0); -#else - ignore = smem; - ignore = rsrc; - ignore = v_offset; - ignore = i_offset; -#endif - } -}; - -template -struct async_buffer_load<4, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem, - int32x4_t rsrc /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { -#if defined(__gfx950__) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(v_offset), "s"(rsrc), "n"(i_offset) - : "memory"); - // __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 4, i_offset, v_offset, 0, 0, 0); -#else - ignore = smem; - ignore = rsrc; - ignore = v_offset; - ignore = i_offset; -#endif - } -}; - template = {}) { -#if defined(__gfx950__) - constexpr index_t bytes = sizeof(T) * N; - static_assert(bytes == 4 || bytes == 12 || bytes == 16, - "wrong! only support in dword, dwordx3, dwordx4"); - async_buffer_load{}(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); -#else static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); async_buffer_load_dword_v(smem, @@ -1767,7 +1668,6 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem, src_immediate_addr_offset, 0, bool_constant{}); -#endif } template = {}) { - // #if defined(__gfx950__) constexpr index_t bytes = sizeof(T) * N; +#if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); +#else + static_assert(bytes == 4, "wrong! not implemented vector size"); +#endif ignore = src_wave_addr_offset; ignore = src_immediate_addr_offset; if constexpr(oob_conditional_check) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 304657bada..cea819b7a8 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -546,7 +546,7 @@ struct tile_window_with_static_distribution { using LdsTileWindow = remove_cvref_t; using LdsDataType = typename LdsTileWindow::DataType; - // #if defined(__gfx950__) +#if defined(__gfx950__) using Traits = load_store_traits; using vector_t = typename Traits::vector_t; @@ -586,76 +586,71 @@ struct tile_window_with_static_distribution } }); }); - // #else +#else + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded - // // issues * warps * lanes - // static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + // TODO: LDS offset is not good for intrinsic based implementation(compiler can't + figure out + // dependency) hence avoid use offset based solution. size_per_buf should be zero + (how to + // check?) + constexpr index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})); - // // TODO: LDS offset is not good for intrinsic based implementation(compiler can't - // figure out - // // dependency) hence avoid use offset based solution. size_per_buf should be zero - // (how to - // // check?) - // constexpr index_t size_per_buf = - // lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - // make_tuple(number<0>{}, number<0>{}, number<0>{})); + constexpr index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) - + size_per_buf; - // constexpr index_t size_per_wave = - // lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - // make_tuple(number<0>{}, number<1>{}, number<0>{})) - - // size_per_buf; + constexpr index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) - + size_per_buf; - // constexpr index_t size_per_issue = - // lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - // make_tuple(number<1>{}, number<0>{}, number<0>{})) - - // size_per_buf; + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - // const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + using Traits = load_store_traits; - // using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; - // using vector_t = typename Traits::vector_t; - // using SFC_Ys = typename Traits::SFC_Ys; + // TODO: we force CK_TILE_LDS_ADDR + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; - // // TODO: we force CK_TILE_LDS_ADDR - // CK_TILE_LDS_ADDR LdsDataType* smem = - // lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; - // // loop over thread tensor space [y0, y1, ...] - // static_for<0, NumCoord, 1>{}([&](auto iCoord) { - // /// TODO: use structure binding (to be captured later) if compiled in C++20 - // auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; - // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; - // static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - // constexpr auto iAccess = number{}; + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements( + smem, bottom_tensor_thread_coord, 0, bool_constant{}); - // // read from bottom tensor - // get_bottom_tensor_view().template - // async_get_vectorized_elements( - // smem, bottom_tensor_thread_coord, 0, - // bool_constant{}); + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - // // move thread coordinate - // if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) - // { - // constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); - // constexpr auto idx_diff_ps_ys = container_concat( - // generate_tuple([&](auto) { return number<0>{}; }, - // number{}), idx_diff_ys); + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - // move_window_adaptor_and_bottom_tensor_thread_coordinate( - // window_adaptor_thread_coord, bottom_tensor_thread_coord, - // idx_diff_ps_ys); - - // smem += size_per_issue; // Note we manually increase the per-issue - // offset - // } - // }); - // }); - // #endif + smem += size_per_issue; // Note we manually increase the per-issue + offset + } + }); + }); +#endif } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 81a8cb3833..3c7a75485d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -33,11 +33,11 @@ struct GemmPipelineAgBgCrImplBase } template - CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_tile, + CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window, SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - async_load_tile(dst_block_tile, dram_tile_window); + async_load_tile(dst_block_window, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_async.cpp b/test/ck_tile/gemm/test_gemm_pipeline_async.cpp deleted file mode 100644 index e279e283cf..0000000000 --- a/test/ck_tile/gemm/test_gemm_pipeline_async.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "test_gemm_pipeline_kernel_types.hpp" -#include "test_gemm_pipeline_util.hpp" -#include "gtest/gtest.h" - -template -class TestCkTileGemmPipelineAsync : public TestCkTileGemmPipeline -{ -}; - -#define TEST_SUITE_NAME TestCkTileGemmPipelineAsync - -TYPED_TEST_SUITE(TestCkTileGemmPipelineAsync, KernelTypesAsync); - -#include "test_gemm_pipeline_ut_cases.inc" - -#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 77e47bb540..bd1502516b 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -20,7 +20,6 @@ using Interwave = ck_tile::integral_constant; using CompV3 = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; -using Async = ck_tile::integral_constant; // clang-format off using KernelTypesMem = ::testing::Types< @@ -60,8 +59,4 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4> >; -using KernelTypesAsync = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Async> ->; - // clang-format on diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index f12a842440..85742cb3de 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -36,8 +36,7 @@ enum struct GemmPipelineType { Mem, CompV3, - CompV4, - Async + CompV4 }; template @@ -64,13 +63,6 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; -template -struct GemmPipelineTypeSelector -{ - using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; - using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; -}; - template void try_run(ck_tile::TailNumber tn) { @@ -105,10 +97,7 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = - (PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async) - ? 32 - : 64; + constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -122,10 +111,7 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; - constexpr bool DoubleSmemBuffer = - (PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async) - ? true - : false; + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false; // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false;