mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
init moe mx f4 scale shuffle
This commit is contained in:
@@ -82,6 +82,7 @@ struct MulABScaleExpertWeight
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 32;
|
||||
@@ -113,6 +114,54 @@ void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
@@ -286,20 +335,27 @@ int main(int argc, char* argv[])
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<A1DataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<A1DataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B1DataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
@@ -310,8 +366,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{0, 1});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-1, 1});
|
||||
break;
|
||||
case 2:
|
||||
@@ -319,8 +373,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
@@ -328,8 +380,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 4:
|
||||
@@ -337,8 +387,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 5:
|
||||
@@ -346,8 +394,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
@@ -355,8 +401,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
@@ -364,8 +408,6 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{}); // will to remove
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{}); // will to remove
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
@@ -373,22 +415,42 @@ int main(int argc, char* argv[])
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(tokenid == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleBuffer<ck::is_same_v<A0Layout, Row>>(
|
||||
a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K);
|
||||
preShuffleBuffer<ck::is_same_v<B0Layout, Row>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_t_k_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
|
||||
d0_device_buf.ToDevice(d0_t_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_e_n.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
|
||||
@@ -187,12 +187,23 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
|
||||
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
|
||||
@@ -190,6 +190,10 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MXdlPack = 2;
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA,
|
||||
@@ -198,8 +202,8 @@ struct GridwiseMoeGemmMX
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize);
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
|
||||
@@ -209,10 +213,6 @@ struct GridwiseMoeGemmMX
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
static constexpr index_t MWave = MPerBlock / MPerXdl / MXdlPerWave;
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(KPack * mfma_selector::selected_mfma.num_input_blks) / ScaleBlockSize;
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / mfma_selector::selected_mfma.num_input_blks;
|
||||
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
@@ -712,10 +712,10 @@ struct GridwiseMoeGemmMX
|
||||
TopK_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
K_ / APackedSize,
|
||||
StrideA_ / APackedSize,
|
||||
StrideScaleA_,
|
||||
StrideB_,
|
||||
StrideB_ / APackedSize,
|
||||
StrideScaleB_,
|
||||
StrideDs_,
|
||||
StrideC_,
|
||||
@@ -784,21 +784,23 @@ struct GridwiseMoeGemmMX
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA;
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA;
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB;
|
||||
b_scale_k_split_offset =
|
||||
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
@@ -1011,6 +1013,9 @@ struct GridwiseMoeGemmMX
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
|
||||
"KPerBlock should be multiple of ScaleBlockSize");
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
@@ -1211,6 +1216,14 @@ struct GridwiseMoeGemmMX
|
||||
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
|
||||
// NPerBlock>;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -1246,17 +1259,17 @@ struct GridwiseMoeGemmMX
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
ScalesPerXdlopsRunPerThread),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
|
||||
ScalesPerXdlopsRunPerThread,
|
||||
1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -1331,6 +1344,7 @@ struct GridwiseMoeGemmMX
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1430,60 +1444,43 @@ struct GridwiseMoeGemmMX
|
||||
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
|
||||
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
|
||||
mfma.selected_mfma.num_threads_per_blk;
|
||||
auto thread_offset_shuffled =
|
||||
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
// get each thread's offset in the scale tensor
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock;
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
|
||||
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
|
||||
const index_t fused_token =
|
||||
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scale_gather_offsets(m0) =
|
||||
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
|
||||
});
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
|
||||
Sequence<1, 1, 1>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true,
|
||||
MXdlPerWave,
|
||||
KRepeat>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
|
||||
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_a));
|
||||
|
||||
// B scale load
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
|
||||
auto b_thread_offset_n = waveId_n;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1,
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
|
||||
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
0,
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user