mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
improve codes
This commit is contained in:
@@ -78,9 +78,6 @@ struct buffer_store;
|
||||
template <index_t bytes>
|
||||
struct buffer_store_if;
|
||||
|
||||
template <index_t bytes, bool pre_nop = false>
|
||||
struct async_buffer_load;
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
|
||||
// TODO: the strict aliasing rule seems to fail when using reinterpret_cast between vector types
|
||||
@@ -1651,90 +1648,6 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
}
|
||||
}
|
||||
|
||||
// Specialization of async_buffer_load for a 16-byte access (the first template argument,
// `bytes` on the primary template, is 16).
//
// operator() arguments:
//   smem     - CK_TILE_LDS_ADDR pointer; it appears only as the asm output operand (%0) and
//              is never referenced in the instruction text.
//   rsrc     - buffer resource, asm operand %2 (constraint "s").
//   v_offset - asm operand %1 (constraint "v"), used together with the `offen` modifier.
//   s_offset - unnamed and unused (4th positional parameter).
//   i_offset - asm operand %3 (constraint "n", i.e. it must be an integer constant); emitted
//              as the `offset:` field. The existing note says max 0xFFF.
//   flag, bool_constant<pre_nop> - unnamed and unused.
//
// NOTE(review): `pre_nop` is not consulted in the body; "s_nop 4" is always emitted on the
// gfx950 path - confirm this is intended.
template <bool pre_nop>
|
||||
struct async_buffer_load<16, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t rsrc /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
// Only the gfx950 build emits an instruction; every other target takes the #else branch.
#if defined(__gfx950__)
|
||||
// Emits `s_nop 4` followed by `buffer_load_dwordx4 ... lds`. The literal 0 in the operand
// list presumably fills the scalar-offset slot (s_offset is unused) - TODO confirm.
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx4 %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
|
||||
: "memory");
|
||||
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 16, i_offset, v_offset, 0, 0, 0);
|
||||
// Non-gfx950 targets: no instruction is emitted; the named arguments are only assigned to
// `ignore` (presumably to silence unused-parameter warnings).
#else
|
||||
ignore = smem;
|
||||
ignore = rsrc;
|
||||
ignore = v_offset;
|
||||
ignore = i_offset;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization of async_buffer_load for a 12-byte access (the first template argument,
// `bytes` on the primary template, is 12).
//
// operator() arguments:
//   smem     - CK_TILE_LDS_ADDR pointer; it appears only as the asm output operand (%0) and
//              is never referenced in the instruction text.
//   rsrc     - buffer resource, asm operand %2 (constraint "s").
//   v_offset - asm operand %1 (constraint "v"), used together with the `offen` modifier.
//   s_offset - unnamed and unused (4th positional parameter).
//   i_offset - asm operand %3 (constraint "n", i.e. it must be an integer constant); emitted
//              as the `offset:` field. The existing note says max 0xFFF.
//   flag, bool_constant<pre_nop> - unnamed and unused.
//
// NOTE(review): `pre_nop` is not consulted in the body; "s_nop 4" is always emitted on the
// gfx950 path - confirm this is intended.
template <bool pre_nop>
|
||||
struct async_buffer_load<12, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t rsrc /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
// Only the gfx950 build emits an instruction; every other target takes the #else branch.
#if defined(__gfx950__)
|
||||
// Emits `s_nop 4` followed by `buffer_load_dwordx3 ... lds`. The literal 0 in the operand
// list presumably fills the scalar-offset slot (s_offset is unused) - TODO confirm.
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dwordx3 %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
|
||||
: "memory");
|
||||
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 12, i_offset, v_offset, 0, 0, 0);
|
||||
// Non-gfx950 targets: no instruction is emitted; the named arguments are only assigned to
// `ignore` (presumably to silence unused-parameter warnings).
#else
|
||||
ignore = smem;
|
||||
ignore = rsrc;
|
||||
ignore = v_offset;
|
||||
ignore = i_offset;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization of async_buffer_load for a 4-byte access (the first template argument,
// `bytes` on the primary template, is 4).
//
// operator() arguments:
//   smem     - CK_TILE_LDS_ADDR pointer; it appears only as the asm output operand (%0) and
//              is never referenced in the instruction text.
//   rsrc     - buffer resource, asm operand %2 (constraint "s").
//   v_offset - asm operand %1 (constraint "v"), used together with the `offen` modifier.
//   s_offset - unnamed and unused (4th positional parameter).
//   i_offset - asm operand %3 (constraint "n", i.e. it must be an integer constant); emitted
//              as the `offset:` field. The existing note says max 0xFFF.
//   flag, bool_constant<pre_nop> - unnamed and unused.
//
// NOTE(review): `pre_nop` is not consulted in the body; "s_nop 4" is always emitted on the
// gfx950 path - confirm this is intended.
template <bool pre_nop>
|
||||
struct async_buffer_load<4, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t rsrc /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag*/ = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
// Only the gfx950 build emits an instruction; every other target takes the #else branch.
#if defined(__gfx950__)
|
||||
// Emits `s_nop 4` followed by `buffer_load_dword ... lds`. The literal 0 in the operand
// list presumably fills the scalar-offset slot (s_offset is unused) - TODO confirm.
asm volatile("s_nop 4\n"
|
||||
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
|
||||
: "=r"(smem) /*dummy dependency for smem*/
|
||||
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
|
||||
: "memory");
|
||||
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 4, i_offset, v_offset, 0, 0, 0);
|
||||
// Non-gfx950 targets: no instruction is emitted; the named arguments are only assigned to
// `ignore` (presumably to silence unused-parameter warnings).
#else
|
||||
ignore = smem;
|
||||
ignore = rsrc;
|
||||
ignore = v_offset;
|
||||
ignore = i_offset;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
@@ -1746,18 +1659,6 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
async_buffer_load<bytes, pre_nop>{}(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
#else
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
async_buffer_load_dword_v(smem,
|
||||
@@ -1767,7 +1668,6 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
src_immediate_addr_offset,
|
||||
0,
|
||||
bool_constant<pre_nop>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -1782,10 +1682,13 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// #if defined(__gfx950__)
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
#if defined(__gfx950__)
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"wrong! only support in dword, dwordx3, dwordx4");
|
||||
#else
|
||||
static_assert(bytes == 4, "wrong! not implemented vector size");
|
||||
#endif
|
||||
ignore = src_wave_addr_offset;
|
||||
ignore = src_immediate_addr_offset;
|
||||
if constexpr(oob_conditional_check)
|
||||
|
||||
@@ -546,7 +546,7 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
// #if defined(__gfx950__)
|
||||
#if defined(__gfx950__)
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
@@ -586,76 +586,71 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
});
|
||||
});
|
||||
// #else
|
||||
#else
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// // issues * warps * lanes
|
||||
// static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't
|
||||
figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero
|
||||
(how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
// // TODO: LDS offset is not good for intrinsic based implementation(compiler can't
|
||||
// figure out
|
||||
// // dependency) hence avoid use offset based solution. size_per_buf should be zero
|
||||
// (how to
|
||||
// // check?)
|
||||
// constexpr index_t size_per_buf =
|
||||
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
// make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
// constexpr index_t size_per_wave =
|
||||
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
// make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
// size_per_buf;
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
// constexpr index_t size_per_issue =
|
||||
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
// make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
// size_per_buf;
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
// const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// using vector_t = typename Traits::vector_t;
|
||||
// using SFC_Ys = typename Traits::SFC_Ys;
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// // TODO: we force CK_TILE_LDS_ADDR
|
||||
// CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
// lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
// // loop over thread tensor space [y0, y1, ...]
|
||||
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
// /// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
// auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
// auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
// constexpr auto iAccess = number<iCoord * NumAccessPerCoord +
|
||||
// iCoordAccess>{};
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// // read from bottom tensor
|
||||
// get_bottom_tensor_view().template
|
||||
// async_get_vectorized_elements<vector_t>(
|
||||
// smem, bottom_tensor_thread_coord, 0,
|
||||
// bool_constant<oob_conditional_check>{});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
// // move thread coordinate
|
||||
// if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
// {
|
||||
// constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
// constexpr auto idx_diff_ps_ys = container_concat(
|
||||
// generate_tuple([&](auto) { return number<0>{}; },
|
||||
// number<NDimP>{}), idx_diff_ys);
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
// window_adaptor_thread_coord, bottom_tensor_thread_coord,
|
||||
// idx_diff_ps_ys);
|
||||
|
||||
// smem += size_per_issue; // Note we manually increase the per-issue
|
||||
// offset
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
// #endif
|
||||
smem += size_per_issue; // Note we manually increase the per-issue
|
||||
offset
|
||||
}
|
||||
});
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
|
||||
@@ -33,11 +33,11 @@ struct GemmPipelineAgBgCrImplBase
|
||||
}
|
||||
|
||||
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_tile,
|
||||
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
async_load_tile(dst_block_tile, dram_tile_window);
|
||||
async_load_tile(dst_block_window, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Test fixture for the async GEMM pipeline typed tests. It derives publicly from
// TestCkTileGemmPipeline<T> and adds no members of its own.
template <typename T>
|
||||
class TestCkTileGemmPipelineAsync : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineAsync
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineAsync, KernelTypesAsync);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -20,7 +20,6 @@ using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
|
||||
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
using Async = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Async>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
@@ -60,8 +59,4 @@ using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
// Type list for the Async pipeline tests: a single configuration tuple.
// NOTE(review): the tuple fields presumably follow the same order as the other KernelTypes*
// lists in this header (layouts, data types, scheduler, pipeline tag) - confirm against the
// fixture that unpacks them.
using KernelTypesAsync = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Async>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -36,8 +36,7 @@ enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
CompV3,
|
||||
CompV4,
|
||||
Async
|
||||
CompV4
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
@@ -64,13 +63,6 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
|
||||
};
|
||||
|
||||
// Selector specialization for GemmPipelineType::Async. For a given Problem it exposes:
//   base_pipeline -> ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>
//   pipeline      -> ck_tile::GemmPipelineAgBgCrCompAsync<Problem>
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::Async, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync<Problem>;
|
||||
};
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
{
|
||||
@@ -105,10 +97,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile =
|
||||
(PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async)
|
||||
? 32
|
||||
: 64;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -122,10 +111,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
(PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async)
|
||||
? true
|
||||
: false;
|
||||
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
|
||||
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
Reference in New Issue
Block a user