mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Comment Addressed
This commit is contained in:
@@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
|
||||
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166)
|
||||
* Added Vectorize Transpose optimization for CK Tile (#2131)
|
||||
* Added the asynchronous copy for gfx950 (#2425)
|
||||
|
||||
|
||||
### Fixes
|
||||
|
||||
@@ -159,7 +159,16 @@ struct TileCopy
|
||||
|
||||
if(my_id == warp_id)
|
||||
{
|
||||
if constexpr(AsyncCopy == false)
|
||||
if constexpr(AsyncCopy)
|
||||
{
|
||||
async_load_tile(x_block_lds_window_no_dist, x_block_window);
|
||||
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
// load from DRAM to registers
|
||||
load_tile(dram_tile, x_block_window);
|
||||
@@ -170,15 +179,6 @@ struct TileCopy
|
||||
// read from lds to registers
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
async_load_tile(x_block_lds_window_no_dist, x_block_window);
|
||||
|
||||
load_tile(dram_tile, x_block_lds_window);
|
||||
|
||||
// store from registers to DRAM
|
||||
store_tile(y_block_window, dram_tile);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#define LIKELY(x) (__builtin_expect(!!(x), 1))
|
||||
#endif
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
@@ -1271,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1791,8 +1793,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
@@ -1803,8 +1804,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
@@ -1818,8 +1818,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
@@ -1830,8 +1829,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
@@ -2812,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
@@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
as3_uint32_ptr lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
@@ -1560,8 +1562,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
0,
|
||||
@@ -1572,8 +1573,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
@@ -1587,8 +1587,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
@@ -1599,8 +1598,7 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(smem)),
|
||||
reinterpret_cast<as3_uint32_ptr t*>(reinterpret_cast<uintptr_t>(smem)),
|
||||
bytes,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
@@ -2581,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
as3_uint32_ptr lds_ptr =
|
||||
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
|
||||
@@ -349,34 +349,41 @@ struct tile_window_with_static_distribution
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
// Precompute invariant values outside loops
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
lds_tile.get_window_origin() + window_adaptor_thread_coord.get_bottom_index();
|
||||
|
||||
const auto lds_coord = make_tensor_coordinate(
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
|
||||
lds_bottom_tensor_thread_idx);
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
|
||||
lds_coord.get_offset();
|
||||
// write into bottom tensor
|
||||
// Use precomputed window origin
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_thread_coord.get_bottom_index();
|
||||
|
||||
// Use precomputed tensor descriptor
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
|
||||
// Write into bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
number<0>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
// Move thread coordinate if not last access
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
@@ -557,31 +557,34 @@ struct tile_window_linear
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
using vector_t = typename traits::vector_t;
|
||||
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
address_space_enum::global,
|
||||
"Requires global memory");
|
||||
|
||||
// Precompute invariant values outside the lambda
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
|
||||
// Use precomputed values
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
lds_tile.get_window_origin() + window_adaptor_coord.get_bottom_index();
|
||||
|
||||
window_origin + window_adaptor_coord.get_bottom_index();
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(lds_tile.get_bottom_tensor_view().get_tensor_descriptor(),
|
||||
lds_bottom_tensor_thread_idx);
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
|
||||
lds_coord.get_offset();
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// read from bottom tensor
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
|
||||
// Read from bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
@@ -589,6 +592,7 @@ struct tile_window_linear
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user