mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Merge commit '706c2b281caa201d2c9064e8940e0eb6c9e6710b' into develop
This commit is contained in:
2
.github/workflows/therock-test-component.yml
vendored
2
.github/workflows/therock-test-component.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
--group-add video
|
||||
--device /dev/kfd
|
||||
--device /dev/dri
|
||||
--group-add 992
|
||||
--group-add 110
|
||||
--env-file /etc/podinfo/gha-gpu-isolation-settings
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -139,6 +139,34 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
|
||||
struct waitcnt_arg
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
// gfx12 uses s_wait_loadcnt_dscnt; in its operand, ds occupies bits [5:0] and mem bits [13:8]
|
||||
CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & (cnt << 8);
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_expcnt()
|
||||
{
|
||||
return 0; // no export in MI series
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & cnt;
|
||||
}
|
||||
#else
|
||||
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
|
||||
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
|
||||
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
|
||||
@@ -167,6 +195,7 @@ struct waitcnt_arg
|
||||
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
|
||||
return MAX & (cnt << 8);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
@@ -174,9 +203,18 @@ template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
|
||||
CK_TILE_DEVICE void s_waitcnt()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
// GFX12 doesn't use __builtin_amdgcn_s_waitcnt
|
||||
constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
|
||||
waitcnt_arg::from_expcnt<expcnt>() |
|
||||
waitcnt_arg::from_lgkmcnt<lgkmcnt>();
|
||||
|
||||
asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
|
||||
#else
|
||||
__builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
|
||||
waitcnt_arg::from_expcnt<expcnt>() |
|
||||
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
@@ -184,8 +222,23 @@ template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
|
||||
CK_TILE_DEVICE void s_waitcnt_barrier()
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
// GFX12 optimization: Manual barrier implementation avoids performance penalty
|
||||
// from __builtin_amdgcn_s_barrier which inserts extra s_wait_loadcnt_dscnt 0x0
|
||||
constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
|
||||
waitcnt_arg::from_expcnt<expcnt>() |
|
||||
waitcnt_arg::from_lgkmcnt<lgkmcnt>();
|
||||
|
||||
asm volatile("s_wait_loadcnt_dscnt %0\n"
|
||||
"s_barrier_signal -1\n"
|
||||
"s_barrier_wait -1"
|
||||
:
|
||||
: "n"(wait_mask)
|
||||
: "memory");
|
||||
#else
|
||||
s_waitcnt<vmcnt, expcnt, lgkmcnt>();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t lgkmcnt = 0>
|
||||
|
||||
@@ -797,7 +797,7 @@ struct MoeSortingKernel
|
||||
else
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
|
||||
}
|
||||
__syncthreads(); // make sure different i_token iteration not overlap by different wave
|
||||
}
|
||||
@@ -922,7 +922,7 @@ struct MoeSortingKernel
|
||||
// NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
|
||||
// for above write however __syncthreads will cause barrier with waves other
|
||||
// than 0 (which is not what we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
|
||||
}
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
|
||||
@@ -73,14 +73,10 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
@@ -94,168 +90,47 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only use this ColumnMajor layout for Wave64 mode (gfx9)
|
||||
constexpr auto Wave64 = get_warp_size() == 64;
|
||||
if constexpr(Wave64 &&
|
||||
std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
// kfold and mpair dimensions are not always required.
|
||||
// more dimensions in merge_transform increase the difficulty of generating immarg
|
||||
// offset for compiler.
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern()>;
|
||||
// AK1
|
||||
constexpr auto AK1 = number<VecLoadSize>{};
|
||||
constexpr auto AK0 = number<KPerBlock / AK1>{};
|
||||
// How the M dimension is split across threads
|
||||
constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim
|
||||
constexpr auto M1 = number<MPerBlock / M0>{};
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
// Get the warp tile size
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr auto MPerXdl = number<WarpTile::at(I0)>{};
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
// How many elements we can write by single thread to LDS,
|
||||
// the transposed / shuffled tile dstr has size: <X1, Y2>
|
||||
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
|
||||
constexpr auto K0PerThreadWrite = integer_divide_ceil(AK0, KThreadWrite);
|
||||
constexpr auto KThreadRead = get_warp_size() / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = integer_divide_ceil(AK0, KThreadRead);
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto LdsBanksWidth = 128;
|
||||
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 &&
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead)
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// 1 <= mpair <= M0
|
||||
constexpr auto mpair =
|
||||
(AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0
|
||||
? M0
|
||||
: LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType)));
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{},
|
||||
number<mpair>{},
|
||||
AK1),
|
||||
AK1);
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(make_tuple(number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
AK1)),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_ak0_m_ak1;
|
||||
}
|
||||
else // A is in RowMajor
|
||||
{
|
||||
constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1
|
||||
? 1
|
||||
: (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,12 +143,12 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
#if 1
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
@@ -285,169 +160,178 @@ struct UniversalGemmBasePolicy
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
// else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
// Only use this RowMajor layout for Wave64 mode (gfx9)
|
||||
constexpr auto Wave64 = get_warp_size() == 64;
|
||||
if constexpr(Wave64 && std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
// BK1
|
||||
constexpr auto BK1 = number<VecLoadSize>{};
|
||||
constexpr auto BK0 = number<KPerBlock / BK1>{};
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
// How threads access data on N dim
|
||||
constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim
|
||||
constexpr auto N1 = number<NPerBlock / N0>{};
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
// Get NPerXdl, the warp tile size
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// How many elements we can write by single thread to LDS,
|
||||
// the transposed / shuffled tile dstr has size: <X1, Y2>
|
||||
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
|
||||
constexpr auto K0PerThreadWrite = integer_divide_ceil(BK0, KThreadWrite);
|
||||
constexpr auto KThreadRead = get_warp_size() / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = integer_divide_ceil(BK0, KThreadRead);
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
// check if we exceed the combined width of all 32 banks (32 x 4B)
|
||||
constexpr auto LdsBanksWidth = 128;
|
||||
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 &&
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead)
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=n0
|
||||
constexpr auto npair =
|
||||
(BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
BK1),
|
||||
BK1);
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(make_tuple(number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<1>{}, // 0: K0PerThreadWrite
|
||||
sequence<2>{}, // 1: KThreadReadPerm
|
||||
sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1
|
||||
sequence<4, 5>{}, // 4: kfold, 5: N0 / npair
|
||||
sequence<6>{}, // 6: npair
|
||||
sequence<7>{})); // 7: BK1
|
||||
|
||||
constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
BK1)),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_nk;
|
||||
}
|
||||
else // B is Column Major
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1
|
||||
? 1
|
||||
: (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(BK0 * number<NLdsLayer>{},
|
||||
number<NPerBlock / NLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
#else
|
||||
else // B is Row Major
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
|
||||
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
|
||||
// constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N0 = TileEncodingPattern::X0;
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
|
||||
|
||||
// constexpr auto KThreadWrite =
|
||||
// BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
|
||||
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=n0
|
||||
constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
|
||||
? 1
|
||||
: ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: 128 / (BK1 * NPerXdl * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
BK1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
// constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
// b_lds_block_desc_unmerged,
|
||||
// make_tuple(make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<KThreadReadPerm>{},
|
||||
// number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
// number<kfold>{},
|
||||
// number<K0PerThreadWrite>{})),
|
||||
// make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
|
||||
// make_pass_through_transform(BK1)),
|
||||
// make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
BK1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
// return b_lds_block_desc_bk0_n_bk1;
|
||||
return b_lds_block_desc_kn;
|
||||
|
||||
// constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
|
||||
// make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
|
||||
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
// number<KPack>{},
|
||||
// number<1>{});
|
||||
|
||||
// constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
// b_lds_block_desc_bk0_n_bk1,
|
||||
// make_tuple(make_pass_through_transform(number<NPerBlock>{}),
|
||||
// make_merge_transform_v3_division_mod(make_tuple(BK0,
|
||||
// number<KPack>{}))),
|
||||
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return b_lds_block_desc;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user