Comment Addressed

This commit is contained in:
ThomasNing
2025-07-06 17:01:11 +00:00
parent 50d2e36380
commit 32b68ff886
6 changed files with 68 additions and 62 deletions

View File

@@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Optimized the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166)
* Added Vectorize Transpose optimization for CK Tile (#2131)
* Added the asynchronous copy for gfx950 (#2425)
### Fixes

View File

@@ -159,7 +159,16 @@ struct TileCopy
if(my_id == warp_id)
{
if constexpr(AsyncCopy == false)
if constexpr(AsyncCopy)
{
async_load_tile(x_block_lds_window_no_dist, x_block_window);
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}
else
{
// load from DRAM to registers
load_tile(dram_tile, x_block_window);
@@ -170,15 +179,6 @@ struct TileCopy
// read from lds to registers
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}
else
{
async_load_tile(x_block_lds_window_no_dist, x_block_window);
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}

View File

@@ -24,6 +24,8 @@
#define LIKELY(x) (__builtin_expect(!!(x), 1))
#endif
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
// 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1271,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1791,8 +1793,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
@@ -1803,8 +1804,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
@@ -1818,8 +1818,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
@@ -1830,8 +1829,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
@@ -2812,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);

View File

@@ -14,6 +14,8 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
namespace ck_tile {
// 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
as3_uint32_ptr lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
@@ -1560,8 +1562,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
0,
@@ -1572,8 +1573,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
0,
@@ -1587,8 +1587,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
@@ -1599,8 +1598,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
{
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(smem)),
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
src_thread_addr_offset,
src_wave_addr_offset,
@@ -2581,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
: "memory");
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
as3_uint32_ptr lds_ptr =
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);

View File

@@ -349,34 +349,41 @@ struct tile_window_with_static_distribution
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// loop over thread tensor space [y0, y1, ...]
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
auto lds_bottom_tensor_thread_idx =
lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index();
const auto lds_coord = make_tensor_coordinate(
lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
lds_coord.get_offset();
// write into bottom tensor
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_thread_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Write into bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
number<0>{},
bool_constant<oob_conditional_check>{});
// move thread coordinate
// Move thread coordinate if not last access
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);

View File

@@ -557,31 +557,34 @@ struct tile_window_linear
using LdsDataType = typename LdsTileWindow::DataType;
using vector_t = typename traits::vector_t;
// currently we only support everything is non linear dim
// actually it's not performant if we have linear dim(e.g. fast changing)
static_assert(NumAccess_NonLinear == NumAccess);
static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global);
address_space_enum::global,
"Requires global memory");
// Precompute invariant values outside the lambda
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
// Use precomputed values
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
auto lds_bottom_tensor_thread_idx =
lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index();
window_origin + window_adaptor_coord.get_bottom_index();
const auto lds_coord =
make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
lds_bottom_tensor_thread_idx);
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
lds_coord.get_offset();
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// read from bottom tensor
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// Read from bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
@@ -589,6 +592,7 @@ struct tile_window_linear
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
};
WINDOW_DISPATCH_ISSUE();
}