diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2c9b227124..ba7eeb1936 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -627,16 +627,9 @@ struct tile_window_with_static_distribution const auto lds_coord = make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); - constexpr auto IMM_RANGE = - (1 << 12) / sizeof(typename Base::DataType) * Traits::PackedSize; - constexpr auto imm_total = lds_ys_offset; - constexpr auto imm_valid = imm_total % IMM_RANGE; - constexpr auto imm_overflow = imm_total - imm_valid; - // Calculate SMEM address using base pointer - CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr + - lds_coord.get_offset() / Traits::PackedSize + - imm_overflow / Traits::PackedSize; + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_base_ptr + (lds_coord.get_offset() + lds_ys_offset) / Traits::PackedSize; const auto dram_ys_offset = [&]() { if constexpr(static_move_ys) @@ -656,13 +649,14 @@ struct tile_window_with_static_distribution offset + dram_ys_offset, bool_constant{}); else + { this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord.get_offset() + offset, - dram_ys_offset - imm_valid, - number{}, + dram_ys_offset, + number<0>{}, bool_constant{}); - + } // Move thread coordinate if not last access if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) {