mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Finished the feature
This commit is contained in:
@@ -1797,14 +1797,14 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
0,
|
||||
0,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( ),
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
@@ -2798,7 +2798,7 @@ template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr16_b64_v4f16),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
|
||||
@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Convenience free-function wrapper around TileWindow_::async_load.
//
// Calls tile_window.async_load(lds_tile, number<i_access>{},
// bool_constant<oob_conditional_check>{}) and returns whatever that call returns.
//
// Parameters:
//   lds_tile    - forwarding reference; handed to async_load as an lvalue (no std::forward).
//                 Named as an LDS tile window; presumably the destination of the load -
//                 TODO confirm against async_load.
//   tile_window - const reference to the window whose async_load member is invoked.
//   number<i_access>, bool_constant<oob_conditional_check>
//               - unnamed tag parameters; only their types are used, and freshly constructed
//                 tags of the same types are passed on.
//
// Template defaults: i_access = -1, oob_conditional_check = true.
// NOTE(review): the meaning of i_access == -1 is decided inside async_load, not here;
// presumably it selects every access - confirm.
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.async_load(
|
||||
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
|
||||
@@ -161,7 +161,8 @@ struct tensor_view
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
@@ -181,7 +182,8 @@ struct tensor_view
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
|
||||
@@ -345,8 +345,44 @@ struct tile_window_with_static_distribution
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index();
|
||||
|
||||
const auto lds_coord = make_tensor_coordinate(
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
|
||||
lds_bottom_tensor_thread_idx);
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
|
||||
lds_coord.get_offset();
|
||||
// write into bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
|
||||
@@ -186,7 +186,7 @@ struct tile_window_linear
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: cached_coords_{}, cached_flags_{}
|
||||
: cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
|
||||
{
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->window_lengths_ = window_lengths;
|
||||
@@ -554,63 +554,32 @@ struct tile_window_linear
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// read from bottom tensor
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index();
|
||||
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
|
||||
lds_bottom_tensor_thread_idx);
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
|
||||
lds_coord.get_offset();
|
||||
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(i_access_ != (NumAccess - 1))
|
||||
{
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
@@ -929,6 +898,7 @@ struct tile_window_linear
|
||||
if constexpr(need_save_non_linear_coord)
|
||||
{
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
|
||||
}
|
||||
|
||||
if constexpr(i_access != (NumAccess - 1))
|
||||
@@ -948,6 +918,7 @@ struct tile_window_linear
|
||||
|
||||
// this contains:
|
||||
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear> cached_window_adaptor_coords_;
|
||||
array<bool, Base::Traits::NumAccess> cached_flags_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user