mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Enable Async Copy for MI355 (#2425)
* add builtin for async load
* add async load api
* fix some compiling errors
* fix a compiling error
* fix some compiling errors
* add a pipeline copied from v4
* add a new pipeline for async load
* fix some compiling errors
* add async load tests
* fix some issues in async load
* fix
* fix async inline assembly
* fix async inline assembly
* add ignore header file
* comment out non-gfx950 code
* comment out non-gfx950 code
* fix an error
* update async load apis
* fix lds descriptor
* fix a compiling error
* fix some compiling errors
* fix a descriptor issue
* update lds descriptor
* change async pipeline's tile distribution pattern from thread to warp
* fix clang format
* update async policy
* fix a CRTP issue
* fix a typo
* change lds layout
* fix some sync issues
* improve code
* delete the async test
* fix a comment formatting issue
* avoid compiling device functions when compiling for host
* make gemm run
* add copy kernel support
* finish the feature
* address review comments
* add support for buffer_builtin
* resolve merge conflicts
* address review comments
---------
Co-authored-by: joye <joye@amd.com>
Co-authored-by: joyeamd <John.Ye@amd.com>
[ROCm/composable_kernel commit: f240ae3248]
@@ -29,6 +29,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166)
 * Added Vectorize Transpose optimization for CK Tile (#2131)
+* Added the asynchronous copy for gfx950 (#2425)

 ### Fixes
@@ -15,7 +15,6 @@
 #define CK_TILE_PIPELINE_COMPUTE_V4 3
 #define CK_TILE_PIPELINE_COMPUTE_V5 4

 // temporary workaround to get k_warp_tile based on PrecType and gfx950 or not
 template <typename PrecType, ck_tile::index_t M_Warp_Tile>
 constexpr ck_tile::index_t get_k_warp_tile()
 {
@@ -53,16 +53,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     x_buf.ToDevice(x_host.data());

-    using BlockWaves = ck_tile::sequence<2, 1>;
-    using BlockTile = ck_tile::sequence<64, 8>;
-    using WaveTile = ck_tile::sequence<64, 8>;
-    using Vector = ck_tile::sequence<1, 4>;
+    using BlockWaves = ck_tile::sequence<2, 1>;
+    using BlockTile = ck_tile::sequence<64, 8>;
+    using WaveTile = ck_tile::sequence<64, 8>;
+    using Vector = ck_tile::sequence<1, 2>;
+    constexpr bool AsyncCopy = true;

     ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
     std::cout << "grid size " << kGridSize << std::endl;

     using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
-    using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
+    using Problem = ck_tile::TileCopyProblem<XDataType, Shape, AsyncCopy>;
     using Kernel = ck_tile::TileCopy<Problem>;

     constexpr ck_tile::index_t kBlockSize = 128;
@@ -50,11 +50,12 @@ struct TileCopyShape
     static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!");
 };

-template <typename XDataType_, typename BlockShape_>
+template <typename XDataType_, typename BlockShape_, bool AsyncCopy_>
 struct TileCopyProblem
 {
-    using XDataType = remove_cvref_t<XDataType_>;
-    using BlockShape = remove_cvref_t<BlockShape_>;
+    using XDataType = remove_cvref_t<XDataType_>;
+    using BlockShape = remove_cvref_t<BlockShape_>;
+    static constexpr bool AsyncCopy = AsyncCopy_;
 };

 template <typename Problem_>
@@ -63,6 +64,8 @@ struct TileCopy
     using Problem = ck_tile::remove_cvref_t<Problem_>;
     using XDataType = typename Problem::XDataType;

+    static constexpr bool AsyncCopy = Problem::AsyncCopy;
+
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
     {
@@ -156,17 +159,29 @@ struct TileCopy
         if(my_id == warp_id)
         {
-            // load from DRAM to registers
-            load_tile(dram_tile, x_block_window);
-
-            // store in lds
-            store_tile(x_block_lds_window_no_dist, dram_tile);
-
-            // read from lds to registers
-            load_tile(dram_tile, x_block_lds_window);
-
-            // store from registers to DRAM
-            store_tile(y_block_window, dram_tile);
+            if constexpr(AsyncCopy)
+            {
+                async_load_tile(x_block_lds_window_no_dist, x_block_window);
+
+                load_tile(dram_tile, x_block_lds_window);
+
+                // store from registers to DRAM
+                store_tile(y_block_window, dram_tile);
+            }
+            else
+            {
+                // load from DRAM to registers
+                load_tile(dram_tile, x_block_window);
+
+                // store in lds
+                store_tile(x_block_lds_window_no_dist, dram_tile);
+
+                // read from lds to registers
+                load_tile(dram_tile, x_block_lds_window);
+
+                // store from registers to DRAM
+                store_tile(y_block_window, dram_tile);
+            }
         }
         __syncthreads();
         move_tile_window(x_block_window, {0, S::Block_N});
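This branch is the heart of the feature: on the async path the tile is written into LDS directly instead of being staged through registers first. A hedged annotation of the new sequence (the window names are from the kernel above; the comments are ours, not from the commit):

    if constexpr(AsyncCopy)
    {
        // global -> LDS in one step: lowers to a buffer load that targets
        // LDS, so the dram_tile register staging and the explicit
        // store_tile(...) into LDS of the else-branch are no longer needed
        async_load_tile(x_block_lds_window_no_dist, x_block_window);

        // LDS -> registers, then registers -> global, as before
        load_tile(dram_tile, x_block_lds_window);
        store_tile(y_block_window, dram_tile);
    }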
@@ -13,6 +13,7 @@
 #include "ck_tile/core/utility/type_traits.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/ignore.hpp"

 // This attribute gives a hint to the compiler that a branch is likely to be taken.
 // Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
@@ -23,6 +24,8 @@
 #define LIKELY(x) (__builtin_expect(!!(x), 1))
 #endif

+using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
+
 namespace ck_tile {

 // 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1270,7 +1273,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
 // Direct loads from global to LDS.
 CK_TILE_DEVICE_EXTERN void
 llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
-                                __attribute__((address_space(3))) uint32_t* lds_ptr,
+                                as3_uint32_ptr lds_ptr,
                                 index_t size,
                                 index_t voffset,
                                 index_t soffset,
@@ -1749,7 +1752,7 @@ template <typename T,
           index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
           bool pre_nop = false>
-CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
+CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
                                                int32x4_t src_wave_buffer_resource,
                                                index_t src_thread_addr_offset,
                                                index_t src_wave_addr_offset,
@@ -1779,29 +1782,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                           index_t flag = 0,
                                           bool_constant<oob_conditional_check> = {})
 {
-    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
+    constexpr index_t bytes = sizeof(T) * N;
+#if defined(__gfx950__)
+    static_assert(bytes == 4 || bytes == 12 || bytes == 16,
+                  "wrong! only support in dword, dwordx3, dwordx4");
+    ignore = src_wave_addr_offset;
+    ignore = src_immediate_addr_offset;
     if constexpr(oob_conditional_check)
     {
         index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
-        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
-                                        smem,
-                                        sizeof(uint32_t),
-                                        v_offset,
-                                        src_wave_addr_offset,
-                                        src_immediate_addr_offset,
-                                        static_cast<index_t>(coherence));
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            v_offset,
+            0,
+            0,
+            static_cast<index_t>(coherence));
     }
     else
     {
-        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
-                                        smem,
-                                        sizeof(uint32_t),
-                                        src_thread_addr_offset,
-                                        src_wave_addr_offset,
-                                        src_immediate_addr_offset,
-                                        static_cast<index_t>(coherence));
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            src_thread_addr_offset,
+            0,
+            0,
+            static_cast<index_t>(coherence));
     }
+#else
+    static_assert(bytes == 4, "wrong! not implemented vector size");
+    if constexpr(oob_conditional_check)
+    {
+        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            v_offset,
+            src_wave_addr_offset,
+            src_immediate_addr_offset,
+            static_cast<index_t>(coherence));
+    }
+    else
+    {
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            src_thread_addr_offset,
+            src_wave_addr_offset,
+            src_immediate_addr_offset,
+            static_cast<index_t>(coherence));
+    }
+#endif
 }

 template <index_t N,
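The new size constraint is plain dword arithmetic: bytes = sizeof(T) * N must match what a single buffer-load-to-LDS instruction can move on gfx950. A few instantiations that would satisfy the static_assert (illustrative element types, not the commit's own tests):

    #include <cstdint>

    static_assert(sizeof(float) * 1 == 4);          // dword
    static_assert(sizeof(std::uint16_t) * 6 == 12); // dwordx3, e.g. 6 fp16/bf16 values
    static_assert(sizeof(std::uint16_t) * 8 == 16); // dwordx4, e.g. 8 fp16/bf16 values

On non-gfx950 targets the #else branch keeps the old rule: only bytes == 4 is accepted.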
@@ -2775,9 +2810,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
         : "memory");
 #else
     // LDS pointer must be attributed with the LDS address space.
-    __attribute__((address_space(3))) uint32_t* lds_ptr =
-        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
-            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
+    as3_uint32_ptr lds_ptr =
+        reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));

     llvm_amdgcn_raw_buffer_load_lds(
         src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
@@ -14,6 +14,8 @@
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"

+using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
+
 namespace ck_tile {

 // 128 bit SGPRs to supply buffer resource in buffer instructions
@@ -1138,7 +1140,7 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
 // Direct loads from global to LDS.
 CK_TILE_DEVICE_EXTERN void
 llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
-                                __attribute__((address_space(3))) uint32_t* lds_ptr,
+                                as3_uint32_ptr lds_ptr,
                                 index_t size,
                                 index_t voffset,
                                 index_t soffset,
@@ -1549,29 +1551,61 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                           index_t flag = 0,
                                           bool_constant<oob_conditional_check> = {})
 {
-    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
+    constexpr index_t bytes = sizeof(T) * N;
+#if defined(__gfx950__)
+    static_assert(bytes == 4 || bytes == 12 || bytes == 16,
+                  "wrong! only support in dword, dwordx3, dwordx4");
+    ignore = src_wave_addr_offset;
+    ignore = src_immediate_addr_offset;
     if constexpr(oob_conditional_check)
     {
-        index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
-        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
-                                        smem,
-                                        sizeof(uint32_t),
-                                        v_offset,
-                                        src_wave_addr_offset,
-                                        src_immediate_addr_offset,
-                                        static_cast<index_t>(coherence));
+        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            v_offset,
+            0,
+            0,
+            static_cast<index_t>(coherence));
     }
     else
     {
-        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
-                                        smem,
-                                        sizeof(uint32_t),
-                                        src_thread_addr_offset,
-                                        src_wave_addr_offset,
-                                        src_immediate_addr_offset,
-                                        static_cast<index_t>(coherence));
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            src_thread_addr_offset,
+            0,
+            0,
+            static_cast<index_t>(coherence));
     }
+#else
+    static_assert(bytes == 4, "wrong! not implemented vector size");
+    if constexpr(oob_conditional_check)
+    {
+        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            v_offset,
+            src_wave_addr_offset,
+            src_immediate_addr_offset,
+            static_cast<index_t>(coherence));
+    }
+    else
+    {
+        llvm_amdgcn_raw_buffer_load_lds(
+            src_wave_buffer_resource,
+            reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
+            bytes,
+            src_thread_addr_offset,
+            src_wave_addr_offset,
+            src_immediate_addr_offset,
+            static_cast<index_t>(coherence));
+    }
+#endif
 }

 template <index_t N,
@@ -2545,9 +2579,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
         : "memory");
 #else
     // LDS pointer must be attributed with the LDS address space.
-    __attribute__((address_space(3))) uint32_t* lds_ptr =
-        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
-            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
+    as3_uint32_ptr lds_ptr =
+        reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));

     llvm_amdgcn_raw_buffer_load_lds(
         src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
@@ -452,10 +452,12 @@ struct buffer_view<address_space_enum::global,
                       "wrong! X should contain multiple T");

         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
+        const int32x4_t src_wave_buffer_resource =
+            make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));

         amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
             smem,
-            cached_buf_res_,
+            src_wave_buffer_resource,
             i,
             linear_offset,
             is_valid_element,
@@ -89,6 +89,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
         tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

+template <typename LdsTileWindow_,
+          typename TileWindow_,
+          index_t i_access = -1,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
+                                    const TileWindow_& tile_window,
+                                    number<i_access> = {},
+                                    bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.async_load(
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
+}
+
 template <typename LdsTileWindow_,
           typename TileWindow_,
           index_t i_access = -1,
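For reference, a minimal usage sketch of the new entry point, assuming a global-memory source window and an LDS destination window created without a tile distribution (the window names follow the TileCopy kernel earlier in this commit):

    // default: issue all accesses with OOB checking enabled
    ck_tile::async_load_tile(x_block_lds_window_no_dist, x_block_window);

    // or pin a single access index and state the OOB check explicitly
    ck_tile::async_load_tile(x_block_lds_window_no_dist,
                             x_block_window,
                             ck_tile::number<0>{},
                             ck_tile::bool_constant<true>{});

The helper only forwards to the source window's async_load, so the window type selects between the static-distribution and linear implementations changed further down.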
@@ -161,7 +161,8 @@ struct tensor_view
     CK_TILE_HOST_DEVICE constexpr void
     async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                                   const TensorCoord& coord,
-                                  index_t linear_offset) const
+                                  index_t linear_offset,
+                                  bool_constant<oob_conditional_check> = {}) const
     {
         return buf_.template async_get<X>(
             smem,
@@ -181,7 +182,8 @@ struct tensor_view
     async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                                   const TensorCoord& coord,
                                   index_t linear_offset,
-                                  bool is_valid_element) const
+                                  bool is_valid_element,
+                                  bool_constant<oob_conditional_check> = {}) const
     {
         return buf_.template async_get<X>(smem,
                                           coord.get_offset() / PackedSize,
@@ -344,64 +344,52 @@ struct tile_window_with_static_distribution
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         using LdsDataType = typename LdsTileWindow::DataType;

-        // issues * warps * lanes
-        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
-
-        // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
-        // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
-        // check?)
-        constexpr index_t size_per_buf =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<0>{}, number<0>{}, number<0>{}));
-
-        constexpr index_t size_per_wave =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<0>{}, number<1>{}, number<0>{})) -
-            size_per_buf;
-
-        constexpr index_t size_per_issue =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<1>{}, number<0>{}, number<0>{})) -
-            size_per_buf;
-
-        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
-
-        using Traits = typename Base::Traits;
+        using Traits = typename Base::Traits;

         using vector_t = typename Traits::vector_t;
         using SFC_Ys = typename Traits::SFC_Ys;

-        // TODO: we force CK_TILE_LDS_ADDR
-        CK_TILE_LDS_ADDR LdsDataType* smem =
-            lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
+        // Precompute invariant values outside loops
+        const auto window_origin = lds_tile.get_window_origin();
+        const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
+        const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
+        auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;

         // loop over thread tensor space [y0, y1, ...]
         static_for<0, NumCoord, 1>{}([&](auto iCoord) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];

             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // read from bottom tensor
-                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
+                // Use precomputed window origin
+                auto lds_bottom_tensor_thread_idx =
+                    window_origin + window_adaptor_thread_coord.get_bottom_index();

-                // move thread coordinate
+                // Use precomputed tensor descriptor
+                const auto lds_coord =
+                    make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
+
+                // Calculate SMEM address using base pointer
+                CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
+
+                // Write into bottom tensor
+                this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
+                    smem,
+                    bottom_tensor_thread_coord,
+                    number<0>{},
+                    bool_constant<oob_conditional_check>{});
+
+                // Move thread coordinate if not last access
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                 {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
-
+                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
                     constexpr auto idx_diff_ps_ys = container_concat(
                         generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
                         idx_diff_ys);

                     Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
                         window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-
-                    smem += size_per_issue; // Note we manually increase the per-issue offset
                 }
             });
         });
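The removed TODO states the motivation: a raw running LDS offset defeats the compiler's dependency analysis around the load-to-LDS intrinsic. So instead of one smem pointer bumped by size_per_issue after each issue, every access now derives its own destination. Condensed, the new addressing reads (a recap of the code above, not a new API):

    // per access: LDS index = window origin + this thread's adaptor index,
    // folded through the LDS tensor descriptor into a linear offset
    auto idx = window_origin + window_adaptor_thread_coord.get_bottom_index();
    const auto lds_coord = make_tensor_coordinate(tensor_descriptor, idx);
    CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();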
@@ -186,7 +186,7 @@ struct tile_window_linear
                        const typename Base::WindowLengths& window_lengths,
                        const typename Base::BottomTensorIndex& window_origin,
                        const typename Base::TileDstr& tile_distribution)
-        : cached_coords_{}, cached_flags_{}
+        : cached_coords_{}, cached_window_adaptor_coords_{}, cached_flags_{}
     {
         this->bottom_tensor_view_ = bottom_tensor_view;
         this->window_lengths_ = window_lengths;
@@ -214,7 +214,8 @@

         if constexpr(need_save_non_linear_coord)
         {
-            cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
+            cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
+            cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
         }

         // TODO: need pad_tensor_view to check which dim need use flag to check
@@ -554,61 +555,42 @@ struct tile_window_linear
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         using LdsDataType = typename LdsTileWindow::DataType;
+        using vector_t = typename traits::vector_t;

         // currently we only support everything is non linear dim
         // actually it's not performant if we have linear dim(e.g. fast changing)
-        static_assert(NumAccess_NonLinear == NumAccess);
+        static_assert(NumAccess_NonLinear == NumAccess, "Unsupported configuration");
         static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
-                      address_space_enum::global);
+                          address_space_enum::global,
+                      "Requires global memory");

-        // issues * warps * lanes
-        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
-
-        // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
-        // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
-        // check?)
-        constexpr index_t size_per_buf =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<0>{}, number<0>{}, number<0>{}));
-
-        constexpr index_t size_per_wave =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<0>{}, number<1>{}, number<0>{})) -
-            size_per_buf;
-
-        constexpr index_t size_per_issue =
-            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
-                make_tuple(number<1>{}, number<0>{}, number<0>{})) -
-            size_per_buf;
-
-        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
-
-        using vector_t = typename Base::Traits::vector_t;
-
-        // TODO: we force CK_TILE_LDS_ADDR
-        CK_TILE_LDS_ADDR LdsDataType* smem =
-            lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
+        // Precompute invariant values outside the lambda
+        const auto window_origin = lds_tile.get_window_origin();
+        const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
+        const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
+        auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;

         // loop over thread tensor space [y0, y1, ...]
         auto issue = [&](auto i_access_) {
-            constexpr auto IAccess = number<i_access_>{};
-            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
+            constexpr auto IAccess = number<i_access_>{};
+            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

+            // Use precomputed values
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
+            auto window_adaptor_coord = cached_window_adaptor_coords_[non_linear_id];
             auto bottom_tensor_flag = cached_flags_[IAccess];

-            // read from bottom tensor
+            auto lds_bottom_tensor_thread_idx =
+                window_origin + window_adaptor_coord.get_bottom_index();
+            const auto lds_coord =
+                make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
+
+            CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
+
+            // Read from bottom tensor
             this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                 smem,
                 bottom_tensor_thread_coord,
                 0,
                 bottom_tensor_flag,
                 bool_constant<oob_conditional_check>{});
-
-            // move thread coordinate
-            if constexpr(i_access_ != (NumAccess - 1))
-            {
-                smem += size_per_issue; // Note we manually increase the per-issue offset
-            }
         };

         WINDOW_DISPATCH_ISSUE();
@@ -928,7 +910,8 @@ struct tile_window_linear

         if constexpr(need_save_non_linear_coord)
         {
-            cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
+            cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
+            cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
         }

         if constexpr(i_access != (NumAccess - 1))
@@ -948,6 +931,8 @@ struct tile_window_linear

     // this contains:
     array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
+    array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear>
+        cached_window_adaptor_coords_;
     array<bool, Base::Traits::NumAccess> cached_flags_;
 };
@@ -32,6 +32,15 @@ struct GemmPipelineAgBgCrImplBase
         move_tile_window(dram_tile_window, dram_tile_window_step);
     }

+    template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
+    CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
+                                            SrcTileWindow& dram_tile_window,
+                                            const DramTileWindowStep& dram_tile_window_step) const
+    {
+        async_load_tile(dst_block_window, dram_tile_window);
+        move_tile_window(dram_tile_window, dram_tile_window_step);
+    }
+
     template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
     CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                      const SrcBlockTile& src_block_tile,
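A hedged sketch of how the new helper could slot into a pipeline prologue; the window and step names here are illustrative, not from the commit:

    // Kick off a direct global -> LDS copy of the A tile, then advance the
    // DRAM window to the next K step. The destination is an LDS-backed block
    // window, so no separate LocalPrefill from registers is required.
    GlobalPrefetchAsync(a_lds_block_window, a_copy_dram_window, {0, KPerBlock});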