diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 36a5b49bf7..e4045870d2 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -214,6 +214,34 @@ CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile, bool_constant{}); } +// Variant of async_load_tile_raw that dispatches to +// async_load_raw_lazy_rebase: the fast buffer_load_dword_lds path, but with +// a wave-uniform SRD base pointer that is lazily re-anchored whenever the +// per-issue page offset would otherwise overflow int32 voffsets. Lifts the +// 4 GB cache-pool limit of the regular async_load_tile_raw without paying +// the per-lane 64-bit base cost of async_load_tile_raw_long. Requires the +// tile_window to have been set up with init_raw_lazy_rebase() and the +// WaveSpanInN <= runtime page_size precondition documented on +// async_load_raw_lazy_rebase. The tile_window is passed by non-const +// reference because the rebase mutates its SRD state. 
+template +CK_TILE_DEVICE void async_load_tile_raw_lazy_rebase( + LdsTileWindow_&& lds_tile, + TileWindow_& tile_window, + number = {}, + bool_constant = {}, + bool_constant = {}) +{ + tile_window.async_load_raw_lazy_rebase(lds_tile, + number{}, + bool_constant{}, + bool_constant{}); +} + CK_TILE_DEVICE void async_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 2730310e20..4d547a21cb 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -719,6 +719,170 @@ struct tile_scatter_gather }); } + // ------------------------------------------------------------------ + // Variant of async_load_raw that lazily re-anchors the wave-uniform SRD + // base pointer so per-lane voffsets stay within int32 range even when + // the total cache pool exceeds 4 GB. For every load issue: + // + // 1. read the per-lane absolute page offset (long_index_t, in + // elements of DataType); + // 2. take lane-0's value as a wave-uniform anchor candidate via + // amd_wave_read_first_lane(); + // 3. if (wave_anchor - cur_anchor_) is outside [0, kRebaseThreshold) + // shift the SRD base pointer to p_data_orig_ + wave_anchor and + // reinit the buffer resource; update cur_anchor_ accordingly; + // 4. issue the buffer_load with voffset = (lane_page_offset - + // cur_anchor_), which is guaranteed to fit in int32 (after the + // *sizeof(T) byte scaling inside amd_async_buffer_load_with_oob_raw). + // + // Correctness precondition: within a single issue every lane of the + // wave must map to the same physical page block, i.e. 
+ // WaveSpanInN <= runtime page_size + // Under this precondition the per-lane spread relative to the + // wave-uniform anchor is bounded by page_size * row_stride * sizeof(T) bytes, + // which must stay inside the half-INT32 byte window that + // kRebaseThreshold (defined below, in elements) leaves free; this method does not check it. When the precondition does not hold use + // async_load_raw_long instead. + // + // Fast path (no rebase this issue): one wave-read, one 64-bit + // subtract, one compare-branch. The branch condition is computed from + // wave-uniform values; rebases are expected to be rare. + // + // This method is non-const because it mutates bottom_tensor_view_ + // (rebase) and cur_anchor_ (anchor tracking). Use after + // init_raw_lazy_rebase(). + template + CK_TILE_DEVICE auto async_load_raw_lazy_rebase( + LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(amd_wave_read_first_lane(m0_init_value)); + + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = 
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // The buffer-load builtin scales the element offset by sizeof(DataType) + // and feeds the result to a 32-bit voffset. To keep the byte offset + // within INT32_MAX *for any active lane in the wave*, leave a margin + // of half the element window for per-lane spread relative to lane-0. + constexpr long_index_t kInt32ElemWindow = + static_cast(INT32_MAX) / static_cast(sizeof(DataType)); + constexpr long_index_t kRebaseThreshold = kInt32ElemWindow / 2; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = get_gather_index(idx_ys_start); + + // Per-lane absolute page offset (in elements of DataType). + const long_index_t lane_page_offset = + static_cast(page_idx_[idx_gather]); + + // Wave-uniform anchor candidate: lane-0's value (or first + // active lane). Promoted to SGPRs by the readfirstlane. + const long_index_t wave_anchor = amd_wave_read_first_lane(lane_page_offset); + + // Lazy rebase: only when the wave-uniform anchor has drifted + // outside [cur_anchor_, cur_anchor_ + kRebaseThreshold). NOTE(review): the rebase narrows the remaining size to BufSizeT; confirm that type is wide enough. 
+ const long_index_t rel = wave_anchor - cur_anchor_; + if(rel < 0 || rel >= kRebaseThreshold) + { + cur_anchor_ = wave_anchor; + bottom_tensor_view_.buf_.p_data_ = p_data_orig_ + cur_anchor_; + using BufSizeT = + remove_cvref_t; + bottom_tensor_view_.buf_.buffer_size_ = + static_cast(buffer_size_orig_ - cur_anchor_); + bottom_tensor_view_.init_raw(); + } + + // Per-lane voffset relative to (possibly new) cur_anchor_. + // Intended to fit in int32 (rel < kRebaseThreshold plus per-lane spread); the cast below is unchecked. + const index_t lane_voffset = + static_cast(lane_page_offset - cur_anchor_); + + // read from bottom tensor + if constexpr(std::is_same_v) + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, lane_voffset, 0, pre_nop_); + } + else + { + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, + bottom_tensor_thread_coord, + lane_voffset, + valids_[idx_gather], + 0, + pre_nop_); + } + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + // TODO: fix with swizzle template (bottom_tensor_view_.buf_.buffer_size_); + cur_anchor_ = 0; + bottom_tensor_view_.init_raw(); + } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; @@ -1302,6 +1481,20 @@ struct tile_scatter_gather array, NumCoord>, std::byte> pre_computed_warp_coords_; + + // State used by async_load_raw_lazy_rebase(). 
Populated by + // init_raw_lazy_rebase(); ignored by all other load paths. + // p_data_orig_ : original SRD base pointer (never mutated post-init) + // buffer_size_orig_ : original SRD size in elements of DataType + // cur_anchor_ : current wave-uniform SRD shift (in elements, + // relative to p_data_orig_); only ever assigned + // wave-uniform values (0 at init, otherwise the result of + // amd_wave_read_first_lane(...)). After each rebase, + // bottom_tensor_view_.buf_.p_data_ == + // p_data_orig_ + cur_anchor_. + typename BottomTensorView::buffer_view::type* p_data_orig_ = nullptr; + long_index_t buffer_size_orig_ = 0; + long_index_t cur_anchor_ = 0; }; // TODO: use strategy diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b79f1190bd..709e7ae23c 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -829,7 +829,15 @@ struct UnifiedAttentionPipeline {0, 0}, k_dist, k_page_offsets); - k_dram_window.init_raw(); + // Use the lazy-rebase-aware init when overflow is possible so the + // rebase path has the original SRD base/size captured. The fast + // path is unaffected: init_raw_lazy_rebase() ends by calling + // init_raw() so the short load path is still valid until the + // first rebase fires. 
+ if(cache_ptr_int32_overflow_possible) + k_dram_window.init_raw_lazy_rebase(); + else + k_dram_window.init_raw(); auto v_dram_window = make_tile_scatter_gather(v_view, @@ -837,7 +845,10 @@ struct UnifiedAttentionPipeline {0, 0}, v_dist, v_page_offsets); - v_dram_window.init_raw(); + if(cache_ptr_int32_overflow_possible) + v_dram_window.init_raw_lazy_rebase(); + else + v_dram_window.init_raw(); // prefetch K tile constexpr index_t k0_loops = 1; @@ -940,27 +951,29 @@ struct UnifiedAttentionPipeline // // Two load paths, dispatched on the runtime overflow flag: // - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a - // wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets - // are int32 so the path is only correct while + // wave-uniform 4 GB-capped SRD. Fastest path; only correct when // `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`. - // - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*` - // with per-lane 64-bit base pointers, lifting the 4 GB limit - // at the cost of lower throughput. + // - true: `async_load_tile_raw_lazy_rebase` → still + // `buffer_load_dword_lds`, but with a wave-uniform SRD base + // pointer that is lazily re-anchored at each issue whenever + // the per-lane page offset would otherwise overflow the int32 + // voffset. Lifts the 4 GB cache-pool limit without paying the + // per-lane 64-bit base cost of `async_load_tile_raw_long`. + // Precondition: WaveSpanInN ≤ runtime page_size (so within a + // single issue every lane of the wave maps to the same physical + // page block and the per-lane spread relative to the + // wave-uniform anchor stays inside a half-INT32 element window). + // If the precondition fails, swap this back to + // `async_load_tile_raw_long` (per-lane 64-bit `global_load_lds`). // The branch is on a wave-uniform value, so no execution divergence. 
// - // We tried a third "per-issue SRD rebase" path - // (`async_load_tile_raw_rebased`: buffer_load_dword_lds with a - // per-issue SRD whose 48-bit base absorbs the wave-uniform page - // offset, valid when WaveSpanInN ≤ runtime page_size). It was - // correct on every big-cache decode shape tested but came out at - // best tied with the long path and at worst ~6% slower (e.g. - // b=1 sk=1M d=64: 2.46 ms vs 2.32 ms; b=8 sk=200k d=128: 2.12 ms - // vs 2.02 ms — see git log for the full table). These workloads - // are compute / softmax bound, not K/V-load bandwidth bound, so - // the buffer_load vs global_load_lds throughput edge never - // materialises, while per-issue SRD construction adds real SGPR - // pressure. The rebased helper has been removed to keep the - // dispatch (and emitted kernel size) minimal. + // History: an earlier "per-issue SRD rebase" path (rebase on every + // issue regardless of whether overflow was imminent) was tested and + // came out at best tied with the long path and at worst ~6% slower + // because per-issue SRD construction adds real SGPR pressure on + // compute/softmax-bound shapes. The current `_lazy_rebase` only + // rebases when the wave anchor drifts outside the current int32 + // voffset window, keeping the fast path register-cheap. 
constexpr index_t KWaveSpanInN = (KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) * KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] + @@ -969,7 +982,8 @@ struct UnifiedAttentionPipeline auto K_mem_load = [&](auto k_lds_write_idx) { if(cache_ptr_int32_overflow_possible) - async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window); + async_load_tile_raw_lazy_rebase(k_lds_window_store(k_lds_write_idx), + k_dram_window); else async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); k_block_idx++; @@ -979,7 +993,8 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { if(cache_ptr_int32_overflow_possible) - async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window); + async_load_tile_raw_lazy_rebase(v_lds_window_store(v_lds_write_idx), + v_dram_window); else async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); v_block_idx++;