improve codes

This commit is contained in:
joye
2025-05-21 15:22:33 +08:00
parent 7377bc7200
commit 2e296ee963
6 changed files with 61 additions and 198 deletions

View File

@@ -78,9 +78,6 @@ struct buffer_store;
template <index_t bytes>
struct buffer_store_if;
template <index_t bytes, bool pre_nop = false>
struct async_buffer_load;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
@@ -1651,90 +1648,6 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
}
}
template <bool pre_nop>
struct async_buffer_load<16, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
int32x4_t rsrc /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
#if defined(__gfx950__)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
: "memory");
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 16, i_offset, v_offset, 0, 0, 0);
#else
ignore = smem;
ignore = rsrc;
ignore = v_offset;
ignore = i_offset;
#endif
}
};
template <bool pre_nop>
struct async_buffer_load<12, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
int32x4_t rsrc /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
#if defined(__gfx950__)
asm volatile("s_nop 4\n"
"buffer_load_dwordx3 %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
: "memory");
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 12, i_offset, v_offset, 0, 0, 0);
#else
ignore = smem;
ignore = rsrc;
ignore = v_offset;
ignore = i_offset;
#endif
}
};
template <bool pre_nop>
struct async_buffer_load<4, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(CK_TILE_LDS_ADDR T* smem,
int32x4_t rsrc /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
#if defined(__gfx950__)
asm volatile("s_nop 4\n"
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(v_offset), "s"(rsrc), "n"(i_offset)
: "memory");
// __builtin_amdgcn_struct_buffer_load_lds(rsrc, smem, 4, i_offset, v_offset, 0, 0, 0);
#else
ignore = smem;
ignore = rsrc;
ignore = v_offset;
ignore = i_offset;
#endif
}
};
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
@@ -1746,18 +1659,6 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{
#if defined(__gfx950__)
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
async_buffer_load<bytes, pre_nop>{}(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
#else
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword_v(smem,
@@ -1767,7 +1668,6 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
#endif
}
template <typename T,
@@ -1782,10 +1682,13 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
// #if defined(__gfx950__)
constexpr index_t bytes = sizeof(T) * N;
#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
#endif
ignore = src_wave_addr_offset;
ignore = src_immediate_addr_offset;
if constexpr(oob_conditional_check)

View File

@@ -546,7 +546,7 @@ struct tile_window_with_static_distribution
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// #if defined(__gfx950__)
#if defined(__gfx950__)
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
@@ -586,76 +586,71 @@ struct tile_window_with_static_distribution
}
});
});
// #else
#else
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// // issues * warps * lanes
// static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't
figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero
(how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
// // TODO: LDS offset is not good for intrinsic based implementation(compiler can't
// figure out
// // dependency) hence avoid use offset based solution. size_per_buf should be zero
// (how to
// // check?)
// constexpr index_t size_per_buf =
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
// make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
// constexpr index_t size_per_wave =
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
// make_tuple(number<0>{}, number<1>{}, number<0>{})) -
// size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
// constexpr index_t size_per_issue =
// lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
// make_tuple(number<1>{}, number<0>{}, number<0>{})) -
// size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
// const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = load_store_traits;
// using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// using vector_t = typename Traits::vector_t;
// using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// // TODO: we force CK_TILE_LDS_ADDR
// CK_TILE_LDS_ADDR LdsDataType* smem =
// lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
// // loop over thread tensor space [y0, y1, ...]
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// /// TODO: use structure binding (to be captured later) if compiled in C++20
// auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
// auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
// constexpr auto iAccess = number<iCoord * NumAccessPerCoord +
// iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// // read from bottom tensor
// get_bottom_tensor_view().template
// async_get_vectorized_elements<vector_t>(
// smem, bottom_tensor_thread_coord, 0,
// bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// // move thread coordinate
// if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
// {
// constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
// constexpr auto idx_diff_ps_ys = container_concat(
// generate_tuple([&](auto) { return number<0>{}; },
// number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
// window_adaptor_thread_coord, bottom_tensor_thread_coord,
// idx_diff_ps_ys);
// smem += size_per_issue; // Note we manually increase the per-issue
// offset
// }
// });
// });
// #endif
smem += size_per_issue; // Note we manually increase the per-issue
offset
}
});
});
#endif
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>

View File

@@ -33,11 +33,11 @@ struct GemmPipelineAgBgCrImplBase
}
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_tile,
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
SrcTileWindow& dram_tile_window,
const DramTileWindowStep& dram_tile_window_step) const
{
async_load_tile(dst_block_tile, dram_tile_window);
async_load_tile(dst_block_window, dram_tile_window);
move_tile_window(dram_tile_window, dram_tile_window_step);
}

View File

@@ -1,16 +0,0 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
template <typename T>
class TestCkTileGemmPipelineAsync : public TestCkTileGemmPipeline<T>
{
};
#define TEST_SUITE_NAME TestCkTileGemmPipelineAsync
TYPED_TEST_SUITE(TestCkTileGemmPipelineAsync, KernelTypesAsync);
#include "test_gemm_pipeline_ut_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -20,7 +20,6 @@ using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV3>;
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using Async = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Async>;
// clang-format off
using KernelTypesMem = ::testing::Types<
@@ -60,8 +59,4 @@ using KernelTypesCompV4 = ::testing::Types<
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
>;
using KernelTypesAsync = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Async>
>;
// clang-format on

View File

@@ -36,8 +36,7 @@ enum struct GemmPipelineType
{
Mem,
CompV3,
CompV4,
Async
CompV4
};
template <GemmPipelineType PT, typename Problem>
@@ -64,13 +63,6 @@ struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, Problem>
using pipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};
template <typename Problem>
struct GemmPipelineTypeSelector<GemmPipelineType::Async, Problem>
{
using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<Problem>;
using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync<Problem>;
};
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn)
{
@@ -105,10 +97,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile =
(PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async)
? 32
: 64;
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -122,10 +111,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr bool DoubleSmemBuffer =
(PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::Async)
? true
: false;
constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;