Merge commit '706c2b281caa201d2c9064e8940e0eb6c9e6710b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-14 16:13:22 +00:00
parent db6db740c4
commit e787672715
4 changed files with 258 additions and 321 deletions

View File

@@ -29,7 +29,7 @@ jobs:
--group-add video
--device /dev/kfd
--device /dev/dri
--group-add 992
--group-add 110
--env-file /etc/podinfo/gha-gpu-isolation-settings
strategy:
fail-fast: false

View File

@@ -139,6 +139,34 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
struct waitcnt_arg
{
#if defined(__gfx12__)
// use s_wait_loadcnt_dscnt; in this instruction, the ds count is bits [5:0] and the mem (load) count is bits [13:8]
CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & (cnt << 8);
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
return 0; // no export in MI series
}
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
return MAX & cnt;
}
#else
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
@@ -167,6 +195,7 @@ struct waitcnt_arg
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
return MAX & (cnt << 8);
}
#endif
};
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
@@ -174,9 +203,18 @@ template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
#if defined(__gfx12__)
// GFX12 doesn't use __builtin_amdgcn_s_waitcnt
constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>();
asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
#else
__builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
#endif
}
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
@@ -184,8 +222,23 @@ template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
#if defined(__gfx12__)
// GFX12 optimization: Manual barrier implementation avoids performance penalty
// from __builtin_amdgcn_s_barrier which inserts extra s_wait_loadcnt_dscnt 0x0
constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>();
asm volatile("s_wait_loadcnt_dscnt %0\n"
"s_barrier_signal -1\n"
"s_barrier_wait -1"
:
: "n"(wait_mask)
: "memory");
#else
s_waitcnt<vmcnt, expcnt, lgkmcnt>();
__builtin_amdgcn_s_barrier();
#endif
}
template <index_t lgkmcnt = 0>

View File

@@ -797,7 +797,7 @@ struct MoeSortingKernel
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
}
@@ -922,7 +922,7 @@ struct MoeSortingKernel
// NOTE: this waitcnt is a must; the compiler will not generate waitcnt lgkmcnt()
// for the above write. However, __syncthreads would cause a barrier with waves
// other than 0 (which is not what we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
}
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
{

View File

@@ -73,14 +73,10 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType);
if constexpr(is_a_load_tr<Problem>)
{
@@ -94,168 +90,47 @@ struct UniversalGemmBasePolicy
}
else
{
// Only use this ColumnMajor layout for Wave64 mode (gfx9)
constexpr auto Wave64 = get_warp_size() == 64;
if constexpr(Wave64 &&
std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg
// offset for compiler.
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
MPerBlock,
VecLoadSize,
getATileAccessPattern()>;
// AK1
constexpr auto AK1 = number<VecLoadSize>{};
constexpr auto AK0 = number<KPerBlock / AK1>{};
// How the M dimension is split across threads
constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim
constexpr auto M1 = number<MPerBlock / M0>{};
constexpr index_t KPack = GetSmemPackA<Problem>();
// Get the warp tile size
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto MPerXdl = number<WarpTile::at(I0)>{};
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
// How many elements we can write by single thread to LDS,
// the transposed / shuffled tile dstr has size: <X1, Y2>
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
constexpr auto K0PerThreadWrite = integer_divide_ceil(AK0, KThreadWrite);
constexpr auto KThreadRead = get_warp_size() / MPerXdl;
constexpr auto K0PerThreadRead = integer_divide_ceil(AK0, KThreadRead);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto LdsBanksWidth = 128;
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 &&
(kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead)
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
// 1<=mpair<=n0
constexpr auto mpair =
(AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth)
? 1
: ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0
? M0
: LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{},
number<mpair>{},
AK1),
AK1);
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(make_tuple(number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(AK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(AK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
AK1)),
make_merge_transform_v3_division_mod(make_tuple(
number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // A is in RowMajor
{
constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1
? 1
: (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
number<MPerBlock / MLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<KPerBlock / KPack * MLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
return a_lds_block_desc;
}
}
@@ -268,12 +143,12 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
#if 1
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better lds descriptor for performance
@@ -285,169 +160,178 @@ struct UniversalGemmBasePolicy
return b_lds_block_desc_0;
}
else
// else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
// Only use this RowMajor layout for Wave64 mode (gfx9)
constexpr auto Wave64 = get_warp_size() == 64;
if constexpr(Wave64 && std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern()>;
// BK1
constexpr auto BK1 = number<VecLoadSize>{};
constexpr auto BK0 = number<KPerBlock / BK1>{};
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
// How threads access data on N dim
constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim
constexpr auto N1 = number<NPerBlock / N0>{};
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
// Get NPerXdl, the warp tile size
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
// How many elements we can write by single thread to LDS,
// the transposed / shuffled tile dstr has size: <X1, Y2>
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
constexpr auto K0PerThreadWrite = integer_divide_ceil(BK0, KThreadWrite);
constexpr auto KThreadRead = get_warp_size() / NPerXdl;
constexpr auto K0PerThreadRead = integer_divide_ceil(BK0, KThreadRead);
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
// check if we exceed all 32banks width - (32x4B)
constexpr auto LdsBanksWidth = 128;
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
? 1
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 &&
(kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead)
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair =
(BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth)
? 1
: ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0
? N0
: LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
BK1),
BK1);
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(make_tuple(number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2, 3>{},
sequence<4>{},
sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(
number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(
sequence<1>{}, // 0: K0PerThreadWrite
sequence<2>{}, // 1: KThreadReadPerm
sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1
sequence<4, 5>{}, // 4: kfold, 5: N0 / npair
sequence<6>{}, // 6: npair
sequence<7>{})); // 7: BK1
constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
BK1)),
make_merge_transform_v3_division_mod(make_tuple(
number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_nk;
}
else // B is Column Major
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1
? 1
: (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(BK0 * number<NLdsLayer>{},
number<NPerBlock / NLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
BK0 * number<NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
#else
else // B is Row Major
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern()>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
// constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N0 = TileEncodingPattern::X0;
constexpr auto N1 = NPerBlock / N0;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
// constexpr auto KThreadWrite =
// BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
constexpr auto kfold =
(BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1 * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
BK1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
// constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_lds_block_desc_unmerged,
// make_tuple(make_merge_transform_v3_division_mod(
// make_tuple(number<KThreadReadPerm>{},
// number<KThreadWrite / kfold / KThreadReadPerm>{},
// number<kfold>{},
// number<K0PerThreadWrite>{})),
// make_merge_transform_v3_division_mod(
// make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
// make_pass_through_transform(BK1)),
// make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
// make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
BK1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// return b_lds_block_desc_bk0_n_bk1;
return b_lds_block_desc_kn;
// constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
// make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
// number<KPack>{},
// number<1>{});
// constexpr auto b_lds_block_desc = transform_tensor_descriptor(
// b_lds_block_desc_bk0_n_bk1,
// make_tuple(make_pass_through_transform(number<NPerBlock>{}),
// make_merge_transform_v3_division_mod(make_tuple(BK0,
// number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc;
}
#endif
}
/**