init moe mx f4 scale shuffle

This commit is contained in:
mtgu0705
2025-05-16 14:46:09 -05:00
parent ec8d00d58d
commit 94fb9190be
3 changed files with 165 additions and 95 deletions

View File

@@ -82,6 +82,7 @@ struct MulABScaleExpertWeight
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 32;
@@ -113,6 +114,54 @@ void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
@@ -286,20 +335,27 @@ int main(int argc, char* argv[])
Tensor<B1DataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
// A, B Scale preshuffle
Tensor<A1DataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<A1DataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B1DataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
@@ -310,8 +366,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{0, 1});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 1});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-1, 1});
break;
case 2:
@@ -319,8 +373,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
@@ -328,8 +380,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 4:
@@ -337,8 +387,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 5:
@@ -346,8 +394,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
@@ -355,8 +401,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
@@ -364,8 +408,6 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
@@ -373,22 +415,42 @@ int main(int argc, char* argv[])
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(tokenid = = tokens)
{
a_scale_sorted(i, k) = 0;
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k);
}
}
}
preShuffleBuffer<ck::is_same_v<A0Layout, Row>>(
a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K);
preShuffleBuffer<ck::is_same_v<B0Layout, Row>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
a1_device_buf.ToDevice(a1_t_k_k.mData.data());
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());

View File

@@ -187,12 +187,23 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto ScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
using mx_scale_t = e8m0_bexp_t;
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;

View File

@@ -190,6 +190,10 @@ struct GridwiseMoeGemmMX
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MXdlPack = 2;
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
using mfma_selector = MfmaSelector<ComputeTypeA,
@@ -198,8 +202,8 @@ struct GridwiseMoeGemmMX
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
@@ -209,10 +213,6 @@ struct GridwiseMoeGemmMX
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
static constexpr auto ScalesPerXdlopsRun =
(KPack * mfma_selector::selected_mfma.num_input_blks) / ScaleBlockSize;
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / mfma_selector::selected_mfma.num_input_blks;
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;
@@ -712,10 +712,10 @@ struct GridwiseMoeGemmMX
TopK_,
M_,
N_,
K_,
StrideA_,
K_ / APackedSize,
StrideA_ / APackedSize,
StrideScaleA_,
StrideB_,
StrideB_ / APackedSize,
StrideScaleB_,
StrideDs_,
StrideC_,
@@ -784,21 +784,23 @@ struct GridwiseMoeGemmMX
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA;
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB;
b_scale_k_split_offset =
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
}
if(k_id < karg.KBatch - 1)
@@ -1011,6 +1013,9 @@ struct GridwiseMoeGemmMX
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
"KPerBlock should be multiple of ScaleBlockSize");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
@@ -1211,6 +1216,14 @@ struct GridwiseMoeGemmMX
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
// NPerBlock>;
using mx_scale_t = e8m0_bexp_t;
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
@@ -1246,17 +1259,17 @@ struct GridwiseMoeGemmMX
problem.NPadded,
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
ScalesPerXdlopsRunPerThread,
ScalesPerXdlopsRunPerThread),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
ScalesPerXdlopsRunPerThread,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -1331,6 +1344,7 @@ struct GridwiseMoeGemmMX
p_b_grid + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1430,60 +1444,43 @@ struct GridwiseMoeGemmMX
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
mfma.selected_mfma.num_threads_per_blk;
auto thread_offset_shuffled =
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
auto a_thread_offset_m = waveId_m;
// get each thread's offset in the scale tensor
const index_t token_scale_pos = block_m_id * MPerBlock;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
const index_t fused_token =
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) =
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
});
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
Sequence<1, 1, 1>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true,
MXdlPerWave,
KRepeat>(
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
0,
thread_offset_shuffled / scale_pack_size_a));
// B scale load
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
auto b_thread_offset_n = waveId_n;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
0,
thread_offset_shuffled / scale_pack_size_b));
if constexpr(IsInputGemm)
{