mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
optimize A_LDS descriptor to avoid bankconflict
This commit is contained in:
@@ -513,44 +513,35 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
PipelinePolicy::template MakeF16xF4_ALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
auto A_XDL_TileDist = make_static_tile_distribution(typename WG::AWarpDstrEncoding{});
|
||||
auto A_Lds_TileDist =
|
||||
PipelinePolicy::template MakeFp16xF4_DS_WRITE_ATileDistribution<Problem>();
|
||||
auto A_Lds_Stride = WG::kK;
|
||||
|
||||
// auto A_XDL_TileDist = PipelinePolicy::template
|
||||
// MakeF16xF4_ALDS_TileDistribution<Problem>(); auto A_Lds_TileDist =
|
||||
// PipelinePolicy::template MakeADramTileDistribution<Problem>(); auto A_Lds_Stride = 8;
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
A_Lds_TileDist);
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
A_Lds_TileDist);
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
A_XDL_TileDist);
|
||||
PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
A_XDL_TileDist);
|
||||
PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
@@ -562,26 +553,22 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
auto A_Lds_Stride = 8;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
// move_tile_window(
|
||||
// a_warp_windows_ping(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// move_tile_window(
|
||||
// a_warp_windows_pong(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter,
|
||||
weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter,
|
||||
weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -673,10 +660,8 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
|
||||
// A_Lds_TileDist may differ with ADramTileDistribution
|
||||
auto a_block_tile_transformed = make_static_distributed_tensor<ComputeType>(A_Lds_TileDist);
|
||||
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -789,8 +774,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
@@ -893,8 +877,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
@@ -1001,8 +984,7 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// GEMM loopK-1
|
||||
|
||||
@@ -17,6 +17,57 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
// Builds the compile-time descriptor used to address the A tile in LDS for the
// F16xF4 flatmm pipeline. The returned descriptor is two-dimensional,
// [MPerBlock, KPerBlock], and is produced in three steps:
//   1. a naive 4-D view [K0, M, XDL_PerWeightK, KPack], K0 = KPerBlock / KPack / XDL_PerWeightK;
//   2. an XOR transform over the (M, K0) dimension pair;
//   3. a merge of (K0, XDL_PerWeightK, KPack) back into a single K dimension.
// NOTE(review): the accompanying commit title says the goal is avoiding LDS bank
// conflicts; presumably the XOR step is what provides that - confirm.
// Template parameter Problem must provide BlockGemmShape (WarpTile, kM, kK) and be
// usable with GetSmemPackA<Problem>().
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// First two entries of the warp tile shape; only the 16x16 case is accepted below.
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
|
||||
/* Uses fewer transform layers than the old CK implementation. */
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
// Number of KPack-sized groups kept adjacent along K (hard-coded to 4 here).
constexpr index_t XDL_PerWeightK = 4;
|
||||
|
||||
// Step 1: 4-D view with lengths [K0, M, XDL_PerWeightK, KPack] and strides
// [KPack * XDL_PerWeightK, KPerBlock, KPack, 1], i.e. consecutive M rows are
// KPerBlock elements apart.
// NOTE(review): the trailing number<KPack>{} / number<1>{} arguments are presumably the
// guaranteed vector length / vector stride of the last dimension - confirm against
// make_naive_tensor_descriptor.
constexpr auto a_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<KPerBlock / KPack / XDL_PerWeightK>{},
|
||||
number<MPerBlock>{},
|
||||
number<XDL_PerWeightK>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack * XDL_PerWeightK>{},
|
||||
number<KPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
// Step 2: XOR transform applied to dimensions (1, 0) = (M, K0); dimensions 2 and 3
// are passed through unchanged. Dimension numbering is the same before and after.
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<KPerBlock / KPack / XDL_PerWeightK>{})),
|
||||
make_pass_through_transform(number<XDL_PerWeightK>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
// Step 3: M (dimension 1) becomes output dimension 0; dimensions (0, 2, 3) =
// (K0, XDL_PerWeightK, KPack) are merged into output dimension 1. The merged length
// equals KPerBlock provided KPerBlock is divisible by KPack * XDL_PerWeightK
// (not asserted in this function).
constexpr auto a_lds_block_desc =
|
||||
transform_tensor_descriptor(a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<KPerBlock / KPack / XDL_PerWeightK>{},
|
||||
number<XDL_PerWeightK>{},
|
||||
number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_ADramTileDistribution()
|
||||
{
|
||||
@@ -49,46 +100,6 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
// Builds a static tile distribution for the A tile, splitting M into (M0, M1, M2) and
// K into (K128_Cnt, XDL_PerKBLoad, XDL_PerKBLoad, K1). Only row-major A is accepted.
// NOTE(review): judging by the name this distribution is meant for the register -> LDS
// write (DS_WRITE) of A - confirm against the call site.
// Template parameter Problem must provide ADataType, ALayout, kBlockSize,
// VectorLoadSize and BlockGemmShape (kM, kK).
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_DS_WRITE_ATileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// K1: VectorLoadSize divided by the element size.
// NOTE(review): presumably VectorLoadSize is in bytes, making K1 the number of A
// elements per vector access - confirm.
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
// K0: number of K1-sized vectors along KPerBlock.
constexpr index_t K0 = KPerBlock / K1;
|
||||
// M2: warp size divided by K0 (integer division; zero when K0 exceeds the warp size,
// which the static_assert below rejects).
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
// M1: block size divided by warp size.
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// M0: remaining factor of MPerBlock once M1 and M2 are taken out; the assert below
// requires the three factors to multiply back to MPerBlock exactly.
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
// unmerge K0 to K16_i x K4_1 x K4_2
|
||||
// then exchange the order of K4_1 and K4_2
|
||||
// In code: K0 is factored as K128_Cnt * XDL_PerKBLoad * XDL_PerKBLoad with
// XDL_PerKBLoad fixed to 4.
// NOTE(review): divisibility of K0 by 16 is not asserted here.
constexpr index_t XDL_PerKBLoad = 4;
|
||||
constexpr index_t K128_Cnt = K0 / XDL_PerKBLoad / XDL_PerKBLoad;
|
||||
|
||||
// NOTE(review): the sequences below select which of the M / K sub-dimensions are
// mapped to which level of the distribution; their exact meaning follows the
// tile_distribution_encoding parameter order, which is not visible here - confirm
// against its definition before changing any index.
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K128_Cnt, XDL_PerKBLoad, XDL_PerKBLoad, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2, 2, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user