Finished the feature

This commit is contained in:
ThomasNing
2025-06-29 01:02:47 -05:00
parent 77b05ed9a3
commit c2ab44bf2f
5 changed files with 73 additions and 51 deletions

View File

@@ -1797,14 +1797,14 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
reinterpret_cast<uintptr_t>(smem)),
sizeof(uint32_t),
v_offset,
0,
0,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( ),
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
@@ -2798,7 +2798,7 @@ template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr16_b64_v4f16),
"We need to have the compatible compiler version to build this instruction");
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{

View File

@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// Free-function entry point that delegates to `tile_window.async_load`, handing it
// `lds_tile` as the first argument.
//
// The two trailing unnamed tag parameters carry no runtime data; they exist so callers can
// select `i_access` and `oob_conditional_check` through argument deduction. Freshly
// constructed tags of the same types are passed on to `async_load`.
//
// Returns whatever `tile_window.async_load(...)` returns.
// NOTE(review): the meaning of the default `i_access = -1` is decided by
// `TileWindow_::async_load`, which is not visible here -- confirm there.
template <typename LdsTileWindow_,
          typename TileWindow_,
          index_t i_access           = -1,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
                                    const TileWindow_& tile_window,
                                    number<i_access>                     = {},
                                    bool_constant<oob_conditional_check> = {})
{
    // Name the compile-time tags once so the forwarding call reads cleanly.
    using access_tag    = number<i_access>;
    using oob_check_tag = bool_constant<oob_conditional_check>;

    return tile_window.async_load(lds_tile, access_tag{}, oob_check_tag{});
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,

View File

@@ -161,7 +161,8 @@ struct tensor_view
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset) const
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(
smem,
@@ -181,7 +182,8 @@ struct tensor_view
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template async_get<X>(smem,
coord.get_offset() / PackedSize,

View File

@@ -345,8 +345,44 @@ struct tile_window_with_static_distribution
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
auto lds_bottom_tensor_thread_idx =
lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index();
const auto lds_coord = make_tensor_coordinate(
lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
lds_coord.get_offset();
// read from bottom tensor into the LDS destination (async get)
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to

View File

@@ -186,7 +186,7 @@ struct tile_window_linear
const typename Base::WindowLengths& window_lengths,
const typename Base::BottomTensorIndex& window_origin,
const typename Base::TileDstr& tile_distribution)
: cached_coords_{}, cached_flags_{}
: cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
{
this->bottom_tensor_view_ = bottom_tensor_view;
this->window_lengths_ = window_lengths;
@@ -554,63 +554,32 @@ struct tile_window_linear
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// currently we only support everything is non linear dim
// actually it's not performant if we have linear dim(e.g. fast changing)
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using vector_t = typename Base::Traits::vector_t;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
// read from bottom tensor
auto lds_bottom_tensor_thread_idx =
lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index();
const auto lds_coord =
make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
lds_coord.get_offset();
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
0,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(i_access_ != (NumAccess - 1))
{
smem += size_per_issue; // Note we manually increase the per-issue offset
}
};
WINDOW_DISPATCH_ISSUE();
}
@@ -929,6 +898,7 @@ struct tile_window_linear
if constexpr(need_save_non_linear_coord)
{
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
}
if constexpr(i_access != (NumAccess - 1))
@@ -948,6 +918,7 @@ struct tile_window_linear
// this contains:
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear> cached_window_adaptor_coords_;
array<bool, Base::Traits::NumAccess> cached_flags_;
};