diff --git a/example/ck_tile/36_copy/CMakeLists.txt b/example/ck_tile/36_copy/CMakeLists.txt deleted file mode 100644 index d1b9ba923c..0000000000 --- a/example/ck_tile/36_copy/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp) -target_compile_options(test_copy_kernel PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) \ No newline at end of file diff --git a/example/ck_tile/36_copy/test_copy.cpp b/example/ck_tile/36_copy/test_copy.cpp deleted file mode 100644 index 4123408453..0000000000 --- a/example/ck_tile/36_copy/test_copy.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/host.hpp" -#include -#include "test_copy.hpp" - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "64", "m dimension") - .insert("n", "8", "n dimension") - .insert("id", "0", "warp to use") - .insert("v", "1", "cpu validation or not") - .insert("prec", "fp16", "precision") - .insert("warmup", "50", "cold iter") - .insert("repeat", "100", "hot iter"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -bool run(const ck_tile::ArgParser& arg_parser) -{ - using XDataType = DataType; - using YDataType = DataType; - - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t warp_id = arg_parser.get_int("id"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - - ck_tile::HostTensor x_host({m, n}); - ck_tile::HostTensor y_host_ref({m, n}); - ck_tile::HostTensor y_host_dev({m, n}); - - // ck_tile::FillConstant{1.f}(x_host); - ck_tile::half_t value = 1; - for(int i = 0; i < m; i++) - { - value = 1; - for(int j = 0; j < n; j++) - { - x_host(i, j) = value++; - } - } - - 
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); - - x_buf.ToDevice(x_host.data()); - - using BlockWaves = ck_tile::sequence<2, 1>; - using BlockTile = ck_tile::sequence<64, 8>; - using WaveTile = ck_tile::sequence<64, 8>; - using Vector = ck_tile::sequence<1, 2>; - constexpr bool AsyncCopy = true; - - ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); - std::cout << "grid size " << kGridSize << std::endl; - - using Shape = ck_tile::TileCopyShape; - using Problem = ck_tile::TileCopyProblem; - using Kernel = ck_tile::TileCopy; - - constexpr ck_tile::index_t kBlockSize = 128; - constexpr ck_tile::index_t kBlockPerCu = 1; - std::cout << "block size " << kBlockSize << std::endl; - std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl; - std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N - << std::endl; - std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " " - << BlockWaves::at(ck_tile::number<1>{}) << std::endl; - std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl; - - float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, - ck_tile::make_kernel( - Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(x_buf.GetDeviceBuffer()), - static_cast(y_buf.GetDeviceBuffer()), - m, - n, - warp_id)); - - std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; - - bool pass = true; - - if(do_validation) - { - // reference - y_buf.FromDevice(y_host_dev.mData.data()); - pass = ck_tile::check_err(y_host_dev, x_host); - - std::cout << "valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; - } - - return pass; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - const std::string data_type = arg_parser.get_str("prec"); - return run(arg_parser) ? 0 : -2; -} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index db5cc71888..b317ed18aa 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -21,6 +21,5 @@ add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) add_subdirectory(35_batched_transpose) -add_subdirectory(36_copy) add_subdirectory(37_transpose) add_subdirectory(38_block_scale_gemm) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 3dd9604b01..e2a73e6242 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -10,6 +10,15 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 +#define CK_TILE_VMCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \ + ((cnt)&0b1111) | (((cnt)&0b110000) << 10)) +#define CK_TILE_EXPCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4)) +#define CK_TILE_LGKMCNT(cnt) \ + ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8)) + namespace ck_tile { template @@ -113,13 +122,72 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) #endif } +// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html +struct waitcnt_arg +{ + // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 + // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV + CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; + + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt 
= 0b111; + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111; + + template + CK_TILE_DEVICE static constexpr index_t from_vmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); + return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10)); + } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]"); + return MAX & (cnt << 4); + } + + template + CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); + return MAX & (cnt << 8); + } +}; + +template +CK_TILE_DEVICE void s_waitcnt() +{ + __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt() | + waitcnt_arg::from_expcnt() | + waitcnt_arg::from_lgkmcnt()); +} + +template +CK_TILE_DEVICE void s_waitcnt_barrier() +{ + s_waitcnt(); + __builtin_amdgcn_s_barrier(); +} + CK_TILE_DEVICE void block_sync_lds_direct_load() { +#if 1 + // invoke clang builtins which *should* produce the same result as the inline asm below + // difference: inline asm is being compiled to wait vmcnt(0) after the barrier + s_waitcnt_barrier<0, waitcnt_arg::kMaxExpCnt, 0>(); +#else + // same content as in old CK (#999) asm volatile("\ s_waitcnt vmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif } CK_TILE_DEVICE void s_nop(index_t cnt = 0) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 8f3fbd52c5..fb566b2a00 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(data_type) add_subdirectory(permute) add_subdirectory(moe_sorting) add_subdirectory(slice_tile) +add_subdirectory(memory_copy) add_subdirectory(batched_transpose) add_subdirectory(smoothquant) add_subdirectory(topk_softmax) diff --git a/test/ck_tile/memory_copy/CMakeLists.txt b/test/ck_tile/memory_copy/CMakeLists.txt new file mode 100644 index 0000000000..5311e5060a --- /dev/null +++ 
b/test/ck_tile/memory_copy/CMakeLists.txt @@ -0,0 +1,3 @@ +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_memory_copy test_copy.cpp) +endif() diff --git a/example/ck_tile/36_copy/README.md b/test/ck_tile/memory_copy/README.md similarity index 100% rename from example/ck_tile/36_copy/README.md rename to test/ck_tile/memory_copy/README.md diff --git a/test/ck_tile/memory_copy/test_copy.cpp b/test/ck_tile/memory_copy/test_copy.cpp new file mode 100644 index 0000000000..e8962dce29 --- /dev/null +++ b/test/ck_tile/memory_copy/test_copy.cpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "test_copy.hpp" + +struct MemoryCopyParam +{ + MemoryCopyParam(ck_tile::index_t m_, ck_tile::index_t n_, ck_tile::index_t warp_id_) + : m(m_), n(n_), warp_id(warp_id_) + { + } + ck_tile::index_t m; + ck_tile::index_t n; + ck_tile::index_t warp_id; +}; + +template +class TestCkTileMemoryCopy : public ::testing::TestWithParam> +{ + protected: + void Run(const MemoryCopyParam& memcpy_params) + { + using XDataType = DataType; + using YDataType = DataType; + + ck_tile::index_t m = memcpy_params.m; + ck_tile::index_t n = memcpy_params.n; + ck_tile::index_t warp_id = memcpy_params.warp_id; + + constexpr auto dword_bytes = 4; + + if(n % (dword_bytes / sizeof(DataType)) != 0) + { + std::cerr << "n size should be multiple of dword_bytes" << std::endl; + } + + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_dev({m, n}); + std::cout << "input: " << x_host.mDesc << std::endl; + std::cout << "output: " << y_host_dev.mDesc << std::endl; + + ck_tile::index_t value = 1; + for(int i = 0; i < m; i++) + { + value = 1; + for(int j = 0; j < n; j++) + { + value = (value + 1) % 127; + x_host(i, j) = static_cast(value); + } + } + + ck_tile::DeviceMem 
x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, dword_bytes / sizeof(DataType)>; + + ck_tile::index_t kGridSize = + ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); + + using Shape = ck_tile::TileCopyShape; + using Problem = ck_tile::TileCopyProblem; + using Kernel = ck_tile::TileCopy; + + constexpr ck_tile::index_t kBlockSize = 128; + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto ms = launch_kernel(ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); + + auto bytes = 2 * m * n * sizeof(DataType); + std::cout << "elapsed: " << ms << " (ms)" << std::endl; + std::cout << (bytes * 1e-6 / ms) << " (GB/s)" << std::endl; + + // reference + y_buf.FromDevice(y_host_dev.mData.data()); + bool pass = ck_tile::check_err(y_host_dev, x_host); + + EXPECT_TRUE(pass); + } +}; + +class TestCkTileMemoryCopyHalfAsync : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyHalfSync : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyFloatAsync : public TestCkTileMemoryCopy +{ +}; + +class TestCkTileMemoryCopyFP8Async : public TestCkTileMemoryCopy +{ +}; + +TEST_P(TestCkTileMemoryCopyHalfAsync, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyHalfSync, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyFloatAsync, TestCorrectness) +{ + auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +TEST_P(TestCkTileMemoryCopyFP8Async, TestCorrectness) +{ + 
auto [M, N, warp_id] = GetParam(); + this->Run({M, N, warp_id}); +} + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyHalfAsync, + ::testing::Values(std::tuple{64, 8, 0}, + std::tuple{63, 8, 0}, + std::tuple{63, 2, 0}, + std::tuple{127, 30, 0}, + std::tuple{64, 8, 1}, + std::tuple{63, 8, 1}, + std::tuple{63, 2, 1}, + std::tuple{127, 30, 1}, + std::tuple{16384, 16384, 0}, + std::tuple{16384, 16384, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyHalfSync, + ::testing::Values(std::tuple{64, 8, 0}, + std::tuple{63, 8, 0}, + std::tuple{63, 2, 0}, + std::tuple{127, 30, 0}, + std::tuple{64, 8, 1}, + std::tuple{63, 8, 1}, + std::tuple{63, 2, 1}, + std::tuple{127, 30, 1}, + std::tuple{16384, 16384, 0}, + std::tuple{16384, 16384, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyFloatAsync, + ::testing::Values(std::tuple{64, 8, 0}, + std::tuple{63, 8, 0}, + std::tuple{63, 2, 0}, + std::tuple{127, 30, 0}, + std::tuple{64, 8, 1}, + std::tuple{63, 8, 1}, + std::tuple{63, 2, 1}, + std::tuple{127, 30, 1}, + std::tuple{16384, 16384, 0}, + std::tuple{16384, 16384, 1})); + +INSTANTIATE_TEST_SUITE_P(TestCkTileMemCopySuite, + TestCkTileMemoryCopyFP8Async, + ::testing::Values(std::tuple{64, 8, 0}, + std::tuple{63, 8, 0}, + std::tuple{63, 4, 0}, + std::tuple{127, 20, 0}, + std::tuple{64, 8, 1}, + std::tuple{63, 8, 1}, + std::tuple{63, 4, 1}, + std::tuple{127, 20, 1}, + std::tuple{16384, 16384, 0}, + std::tuple{16384, 16384, 1})); diff --git a/example/ck_tile/36_copy/test_copy.hpp b/test/ck_tile/memory_copy/test_copy.hpp similarity index 56% rename from example/ck_tile/36_copy/test_copy.hpp rename to test/ck_tile/memory_copy/test_copy.hpp index 0b3c87d472..a9840ba2c6 100644 --- a/example/ck_tile/36_copy/test_copy.hpp +++ b/test/ck_tile/memory_copy/test_copy.hpp @@ -14,14 +14,14 @@ namespace ck_tile { template typename BlockTile, // block size, seq typename WaveTile, // warp size, seq - typename Vector> 
// contiguous elements(vector size) along seq + typename Vector> // contiguous elements (vector size) along seq struct TileCopyShape { // We split Workgroup waves into two specialized groups. - // One for reading data from global -> LDS, the other is doing reduction + // One for reading data from global -> LDS, the other one idles static constexpr index_t WaveGroups = 2; static constexpr index_t MWarps = BlockWaves::at(number<0>{}); - static constexpr index_t NWarps = BlockWaves::at(number<0>{}); + static constexpr index_t NWarps = BlockWaves::at(number<1>{}); static constexpr index_t Block_M = BlockTile::at(number<0>{}); static constexpr index_t Block_N = BlockTile::at(number<1>{}); @@ -35,10 +35,9 @@ struct TileCopyShape static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - static constexpr index_t WarpPerBlock_M = - integer_divide_ceil(BlockWaves::at(number<0>{}), WaveGroups); - static constexpr index_t WarpPerBlock_N = - integer_divide_ceil(BlockWaves::at(number<1>{}), WaveGroups); + // We split the waves on the M dimension + static constexpr index_t WarpPerBlock_M = integer_divide_ceil(MWarps, WaveGroups); + static constexpr index_t WarpPerBlock_N = NWarps; static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); @@ -47,7 +46,8 @@ struct TileCopyShape static constexpr index_t BlockSize = get_warp_size() * WaveNum; static constexpr index_t WaveGroupSize = WaveNum / WaveGroups; - static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!"); + static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, + "Inconsistent wave group size!"); }; template @@ -78,20 +78,21 @@ struct TileCopy S::Vector_N; // no. of elements along N dimensions to be read by each thread. constexpr index_t Y0 = - S::WaveNum / S::WaveGroups; // no. of active warps working in this thread block. 
- constexpr index_t Y1 = warp_size / X0; // no. of threads in a warp needed along M dimension. + S::WaveNum / S::WaveGroups; // number of active warps working in this thread block. constexpr index_t Y2 = + warp_size / X0; // number of threads in a warp needed along M dimension. + constexpr index_t Y1 = S::Warp_M / - (Y1 * - Y0); // no. of iterations each warp needs to perform to cover the entire tile window. + Y2; // number of iterations each warp needs to perform to cover the entire tile window. constexpr auto outer_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 0>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, sequence<1, 2>, sequence<1, 1>>{}; + return make_static_tile_distribution(outer_encoding); } @@ -100,90 +101,69 @@ struct TileCopy { using S = typename Problem::BlockShape; - // LDS Data. - __shared__ XDataType x_lds[number{} * number{}]; - XDataType* __restrict__ p_x_lds = static_cast(x_lds); + // LDS buffer + __shared__ XDataType x_lds[S::Block_M * S::Block_N]; + + constexpr auto block_dims = make_tuple(number{}, number{}); + constexpr auto block_strides = make_tuple(number{}, number<1>{}); const auto x_lds_desc = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, 1), - number{}, - number<1>{}); + block_dims, block_strides, number{}, number<1>{}); - auto x_lds_block_desc = transform_tensor_descriptor( - x_lds_desc, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform( - make_tuple(number{} / S::Vector_N, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + auto x_lds_view = make_tensor_view(x_lds, x_lds_desc); - auto x_lds_view = make_tensor_view(p_x_lds, x_lds_block_desc); + auto x_block_lds_write_window = make_tile_window(x_lds_view, block_dims, {0, 0}); - auto x_block_lds_window = - 
make_tile_window(x_lds_view, - make_tuple(number{}, number{}), - {0, 0}, - MakeDRAMDistribution()); - auto x_block_lds_window_no_dist = make_tile_window( - x_lds_view, make_tuple(number{}, number{}), {0, 0}); + auto x_block_lds_read_window = + make_tile_window(x_lds_view, block_dims, {0, 0}, MakeDRAMDistribution()); + const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M); // Input tensor - const auto iM = get_block_id() * S::Block_M; const auto x_m_n = make_naive_tensor_view( p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); auto x_block_window = - make_tile_window(x_m_n, - make_tuple(number{}, number{}), - {iM, 0}, - MakeDRAMDistribution()); + make_tile_window(x_m_n, block_dims, {iM, 0}, MakeDRAMDistribution()); // Output tensor const auto y_m = make_naive_tensor_view( p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + auto y_block_window = make_tile_window(y_m, block_dims, {iM, 0}); - auto y_block_window = - make_tile_window(y_m, make_tuple(number{}, number{}), {iM, 0}); - - // Programming logic - index_t num_n_tile_iteration = + const index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); - auto my_id = get_warp_id(); - - auto DramTileDist = x_block_window.get_tile_distribution(); - using dram_reg_tile = decltype(make_static_distributed_tensor(DramTileDist)); - + const index_t my_id = __builtin_amdgcn_readfirstlane(get_warp_id()); + constexpr index_t async_copy_fence_cnt = 0; for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - dram_reg_tile dram_tile; - if(my_id == warp_id) { if constexpr(AsyncCopy) { - async_load_tile(x_block_lds_window_no_dist, x_block_window); - - load_tile(dram_tile, x_block_lds_window); - + async_load_tile(x_block_lds_write_window, x_block_window); + // We don't have prefetch here, wait the data back immediately. + // Wait all asyncload insts complete. 
+ // Wait all waves synced + s_waitcnt_barrier(); + auto lds_tile = load_tile(x_block_lds_read_window); // store from registers to DRAM - store_tile(y_block_window, dram_tile); + store_tile(y_block_window, lds_tile); } else { // load from DRAM to registers - load_tile(dram_tile, x_block_window); - + auto dram_tile = load_tile(x_block_window); // store in lds - store_tile(x_block_lds_window_no_dist, dram_tile); - + store_tile(x_block_lds_write_window, dram_tile); + // Wait all lds write insts complete + // Wait all waves synced + block_sync_lds(); // read from lds to registers - load_tile(dram_tile, x_block_lds_window); - + auto lds_tile = load_tile(x_block_lds_read_window); // store from registers to DRAM - store_tile(y_block_window, dram_tile); + store_tile(y_block_window, lds_tile); } } - __syncthreads(); + move_tile_window(x_block_window, {0, S::Block_N}); move_tile_window(y_block_window, {0, S::Block_N}); }