mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK-Tile] functional support for transposed inputs in compute-bound double-lds-buffer pipeline with async loads from global memory to LDS (#2984)
* reuse local prefetch logic from compute v4 pipeline add single-tile test explicit lambda capture reuse lds block descriptors from base policy for the transposed case match the test case kernel configuration with compute v4 * add comments
This commit is contained in:
@@ -152,6 +152,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
@@ -249,11 +252,6 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
// TODO currently only support A matrix row major, B matrix col major; if A matrix is
|
||||
// col major or B is row major, need to combine with transpose load api
|
||||
static_assert(!(is_a_col_major || is_b_row_major),
|
||||
"only support A matrix is row major, B matrix is col major!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
@@ -293,18 +291,29 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
// set up LDS tile shapes
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
// LDS tile windows for storing, one per LDS buffer
|
||||
auto a_copy_lds_window0 = make_tile_window(
|
||||
a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_copy_lds_window1 = make_tile_window(
|
||||
a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
|
||||
|
||||
auto b_copy_lds_window0 = make_tile_window(
|
||||
b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
|
||||
|
||||
auto b_copy_lds_window1 = make_tile_window(
|
||||
b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
|
||||
|
||||
// initialize DRAM window steps, used to advance the DRAM windows
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
@@ -336,44 +345,49 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
// tile distribution for the register tiles
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
|
||||
ALdsTile a_block_tile0;
|
||||
ALdsTile a_block_tile1;
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
BLdsTile b_block_tile0;
|
||||
BLdsTile b_block_tile1;
|
||||
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(ALdsTileDistr)::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(BLdsTileDistr)::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
|
||||
// LDS tile windows for reading;
|
||||
// they share the data pointer with the LDS windows for storing
|
||||
// but also associate with a distribution to produce a register tile when reading
|
||||
auto a_lds_ld_window0 =
|
||||
make_tile_window(a_lds_block0,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsTileDistr);
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto a_lds_ld_window1 =
|
||||
make_tile_window(a_lds_block1,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsTileDistr);
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto b_lds_ld_window0 =
|
||||
make_tile_window(b_lds_block0,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsTileDistr);
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
auto b_lds_ld_window1 =
|
||||
make_tile_window(b_lds_block1,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsTileDistr);
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
|
||||
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
|
||||
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
|
||||
@@ -384,8 +398,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(0), B(0) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// LDS window(0) contents are overwritten below by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(2), B(2) from DRAM to LDS window(0)
|
||||
@@ -406,8 +420,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
// ping
|
||||
{
|
||||
// read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// LDS window(1) contents are overwritten by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(i), B(i) from DRAM to LDS window(1)
|
||||
@@ -427,8 +441,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(i), B(i) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// LDS window(0) contents are overwritten by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(i+1), B(i+1) from DRAM to LDS window(0)
|
||||
@@ -452,15 +466,15 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
{
|
||||
{
|
||||
// read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
{
|
||||
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
@@ -474,8 +488,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
{
|
||||
{
|
||||
// read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
|
||||
@@ -23,21 +23,36 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
// This branch is reusing the logic from
|
||||
// UniversalGemmBasePolicy::MakeALdsBlockDescriptor
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<1>{}),
|
||||
number<MPerBlock>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -45,21 +60,36 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
// This branch is reusing the logic from
|
||||
// UniversalGemmBasePolicy::MakeBLdsBlockDescriptor
|
||||
constexpr auto b_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<1>{}),
|
||||
number<NPerBlock>{},
|
||||
number<1>{});
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user