Enable Async Copy for MI355 (#2425)

* add for async load builtin

* add async load api

* fix some compiling errors

* fix a compiling error

* fix some compiling errors

* add a pipeline which copies from v4

* add a new pipeline for async load

* fix some compiling errors

* add async load tests

* fix some issues in async load

* fix

* fix async inline assembly

* fix async inline assembly

* add ignore header file

* comment some not gfx950 codes

* comment some not gfx950 codes

* fix an error

* update async load apis

* fix lds descriptor

* fix a compiling error

* fix some compiling errors

* fix a descriptor issue

* update lds descriptor

* change async pipeline's tile distribution pattern from thread to warp

* fix clang format

* update async policy

* fix a CRTP issue

* fix a typo error

* change lds layout

* fix some sync issues

* improve codes

* delete the async test

* fix a commented format issue

* avoid compiling device functions when compile host

* make gemm run

* add the copy kernel support

* finish the feature

* Address comment

* add the support for buffer_builtin

* solved the merging problem

* Comment Addressed

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: joyeamd <John.Ye@amd.com>

[ROCm/composable_kernel commit: f240ae3248]
This commit is contained in:
Thomas Ning
2025-07-07 10:08:49 -07:00
committed by GitHub
parent 67545a9d22
commit a01042c3cf
12 changed files with 225 additions and 143 deletions

View File

@@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166)
* Added Vectorize Transpose optimization for CK Tile (#2131)
* Added the asynchronous copy for gfx950 (#2425)
### Fixes

View File

@@ -15,7 +15,6 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
// temporary workaround to get k_warp_tile based on PrecType and gfx950 or not
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{

View File

@@ -53,16 +53,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_buf.ToDevice(x_host.data());
using BlockWaves = ck_tile::sequence<2, 1>;
using BlockTile = ck_tile::sequence<64, 8>;
using WaveTile = ck_tile::sequence<64, 8>;
using Vector = ck_tile::sequence<1, 4>;
using BlockWaves = ck_tile::sequence<2, 1>;
using BlockTile = ck_tile::sequence<64, 8>;
using WaveTile = ck_tile::sequence<64, 8>;
using Vector = ck_tile::sequence<1, 2>;
constexpr bool AsyncCopy = true;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "grid size " << kGridSize << std::endl;
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape, AsyncCopy>;
using Kernel = ck_tile::TileCopy<Problem>;
constexpr ck_tile::index_t kBlockSize = 128;

View File

@@ -50,11 +50,12 @@ struct TileCopyShape
static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!");
};
template <typename XDataType_, typename BlockShape_>
template <typename XDataType_, typename BlockShape_, bool AsyncCopy_>
struct TileCopyProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool AsyncCopy = AsyncCopy_;
};
template <typename Problem_>
@@ -63,6 +64,8 @@ struct TileCopy
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
static constexpr bool AsyncCopy = Problem::AsyncCopy;
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
{
@@ -156,17 +159,29 @@ struct TileCopy
if(my_id == warp_id)
{
// load from DRAM to registers
load_tile(dram_tile, x_block_window);
if constexpr(AsyncCopy)
{
async_load_tile(x_block_lds_window_no_dist, x_block_window);
// store in lds
store_tile(x_block_lds_window_no_dist, dram_tile);
load_tile(dram_tile, x_block_lds_window);
// read from lds to registers
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}
else
{
// load from DRAM to registers
load_tile(dram_tile, x_block_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
// store in lds
store_tile(x_block_lds_window_no_dist, dram_tile);
// read from lds to registers
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}
}
__syncthreads();
move_tile_window(x_block_window, {0, S::Block_N});

View File

@@ -13,6 +13,7 @@
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
// This attribute gives a hint to the compiler that a branch is likely to be taken.
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
@@ -23,6 +24,8 @@
#define LIKELY(x) (__builtin_expect(!!(x), 1))
#endif
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
// 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1749,7 +1752,7 @@ template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
@@ -1779,29 +1782,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
}
template <index_t N,
@@ -2775,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);

View File

@@ -14,6 +14,8 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
// 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1549,29 +1551,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
constexpr index_t bytes = sizeof(T) * N;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
0,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
0,
static_cast<index_t>(coherence));
}
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
#endif
}
template <index_t N,
@@ -2545,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);

View File

@@ -452,10 +452,12 @@ struct buffer_view<address_space_enum::global,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem,
cached_buf_res_,
src_wave_buffer_resource,
i,
linear_offset,
is_valid_element,

View File

@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// Kicks off an asynchronous copy of the tile referenced by `tile_window` into
// the LDS tile window `lds_tile` by forwarding to the window's `async_load`
// member. `i_access` selects a single access to issue (-1 issues all accesses);
// `oob_conditional_check` enables out-of-bounds guarding for edge tiles.
// NOTE(review): the exact completion/synchronization semantics depend on the
// tile-window implementation's async_load — confirm against
// tile_window_with_static_distribution / tile_window_linear.
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,

View File

@@ -161,7 +161,8 @@ struct tensor_view
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset) const
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(
smem,
@@ -181,7 +182,8 @@ struct tensor_view
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(smem,
coord.get_offset() / PackedSize,

View File

@@ -344,64 +344,52 @@ struct tile_window_with_static_distribution
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = typename Base::Traits;
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_thread_coord.get_bottom_index();
// move thread coordinate
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Write into bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
number<0>{},
bool_constant<oob_conditional_check>{});
// Move thread coordinate if not last access
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});

View File

@@ -186,7 +186,7 @@ struct tile_window_linear
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: cached_coords_{}, cached_flags_{}
: cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
{
this->bottom_tensor_view_ = bottom_tensor_view;
this->window_lengths_ = window_lengths;
@@ -214,7 +214,8 @@ struct tile_window_linear
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
}
// TODO: need pad_tensor_view to check which dim need use flag to check
@@ -554,61 +555,42 @@ struct tile_window_linear
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
using vector_t = typename traits::vector_t;
// currently we only support everything is non linear dim
// actually it's not performant if we have linear dim(e.g. fast changing)
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
address_space_enum::global,
"Requires global memory");
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// Precompute invariant values outside the lambda
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename Base::Traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
// Use precomputed values
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_coord.get_bottom_index();
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
};
WINDOW_DISPATCH_ISSUE();
@@ -928,7 +910,8 @@ struct tile_window_linear
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
}
if constexpr(i_access != (NumAccess - 1))
@@ -948,6 +931,8 @@ struct tile_window_linear
// this contains:
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear>
cached_window_adaptor_coords_;
array<bool, Base::Traits::NumAccess> cached_flags_;
};

View File

@@ -32,6 +32,15 @@ struct GemmPipelineAgBgCrImplBase
move_tile_window(dram_tile_window, dram_tile_window_step);
}
// Asynchronous global prefetch: issues an async copy from the DRAM tile window
// into `dst_block_window` (presumably an LDS-backed window — confirm at call
// sites), then advances the DRAM window by `dram_tile_window_step` so the next
// call prefetches the following tile. Counterpart of the synchronous
// GlobalPrefetch above, used when the pipeline's AsyncCopy path is enabled.
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
async_load_tile(dst_block_window, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,