update code

This commit is contained in:
mtgu0705
2025-05-17 09:28:26 -05:00
parent 94fb9190be
commit eeeba8901f
6 changed files with 749 additions and 537 deletions

View File

@@ -24,19 +24,20 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XDataType;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XDataType;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
@@ -170,7 +171,9 @@ using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
#if 0
static constexpr ck::index_t MPerBlock = 128;
@@ -213,14 +216,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, 128,
32, 32,
MPerBlock, 256, KPerBlock,
16, 16,
8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
16, 16,
8, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
#endif
@@ -328,22 +331,22 @@ int main(int argc, char* argv[])
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<A1DataType> a1_t_k_k(
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
// A, B Scale preshuffle
Tensor<A1DataType> a_scale_sorted(HostTensorDescriptor(
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<A1DataType> a_scale_preshuffled(HostTensorDescriptor(
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B1DataType> b_scale_preshuffled(
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
@@ -364,50 +367,50 @@ int main(int argc, char* argv[])
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{0, 1});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_2<XDataType>{0, 1});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_2<XDataType>{0, 1});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-1, 1});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
@@ -415,35 +418,37 @@ int main(int argc, char* argv[])
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem a1_device_buf(sizeof(A1DataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int tokenid = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topkid = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(tokenid = = tokens)
if(token_id == tokens)
{
a_scale_sorted(i, k) = 0;
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(tokenid, topkid, k);
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
preShuffleBuffer<ck::is_same_v<A0Layout, Row>>(
a_scale_sorted.mData.data(), a_scale_preshuffled.mData.data(), sorted_size, K);
preShuffleBuffer<ck::is_same_v<B0Layout, Row>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K);
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Row>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
@@ -614,9 +619,9 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
A1DataType,
XDataType,
B0DataType,
B1DataType,
XDataType,
D2DataType,
CShuffleDataType,
AccDataType,

View File

@@ -76,6 +76,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
// Hardcode to 2, for better 8-bit access pattern
static constexpr index_t MXdlPack = 2;
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<
BlockSize,
MPerBlock,
@@ -127,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -138,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -170,7 +176,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
@@ -190,8 +196,8 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
@@ -327,49 +333,63 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / APackedSize>{}),
make_tuple(Number<KPack / APackedSize>{},
Number<KRepeat * MRepeat * KPack / APackedSize>{},
Number<MRepeat * KPack / APackedSize>{},
I1));
// A[M0, M1, M2, M3, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<MRepeat / MXdlPack>{},
I1,
Number<MXdlPack>{},
Number<KRepeat>{},
Number<KPack>{}),
make_tuple(Number<KPack * MXdlPack>{},
Number<KRepeat * MRepeat * KPack>{},
Number<MRepeat * KPack>{},
Number<KPack>{},
I1));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack / BPackedSize>{}),
make_tuple(Number<KPack / BPackedSize>{},
Number<KRepeat * NRepeat * KPack / BPackedSize>{},
Number<NRepeat * KPack / BPackedSize>{},
I1));
// B[N0, N1, N2, N3, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<NRepeat / NXdlPack>{},
I1,
Number<KRepeat>{},
Number<NXdlPack>{},
Number<KPack>{}),
make_tuple(Number<KPack * NXdlPack>{},
Number<KRepeat * NRepeat * KPack>{},
Number<NRepeat * KPack>{},
Number<KPack>{},
I1));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_block_desc_m0_m1_m2_m3_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_block_desc_n0_n1_n2_n3_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;

View File

@@ -495,7 +495,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf, auto a_buf) {
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
// Prefetch a_scales to buf 1
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
@@ -683,8 +683,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
__builtin_amdgcn_sched_barrier(0);
}; // LoopFunc
LoopFunc(I0, I1, I0);
LoopFunc(I1, I0, I1);
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));

View File

@@ -277,16 +277,35 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales to buf 0
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I0));
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -329,15 +348,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineSched
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {

View File

@@ -194,6 +194,22 @@ struct GridwiseMoeGemmMX
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;
using mfma_selector = MfmaSelector<ComputeTypeA,
@@ -232,22 +248,6 @@ struct GridwiseMoeGemmMX
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
@@ -317,7 +317,11 @@ struct GridwiseMoeGemmMX
return math::integer_divide_ceil(N, NPerBlock);
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
template <index_t MNXdlPerWave,
index_t MNWaves,
index_t MNXdlPack,
index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
@@ -326,10 +330,12 @@ struct GridwiseMoeGemmMX
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_unmerge_transform(make_tuple(Number<MNXdlPerWave / MNXdlPack>{},
Number<MNWaves>{},
Number<MNXdlPack>{},
Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}));
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
@@ -513,16 +519,18 @@ struct GridwiseMoeGemmMX
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MPerXdl>(ABlockDesc_AK0_M_AK1{});
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWave, MXdlPack, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
{
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NXdlPack, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename ELayout>
@@ -789,7 +797,7 @@ struct GridwiseMoeGemmMX
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / PackedSize) * karg.StrideScaleA;
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
}
// Calculate B scale offset
@@ -939,8 +947,11 @@ struct GridwiseMoeGemmMX
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
return make_naive_tensor_descriptor_packed(
make_tuple(Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}));
return make_naive_tensor_descriptor_packed(make_tuple(Number<NXdlPerWave / NXdlPack>{},
I1,
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}));
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
@@ -969,9 +980,9 @@ struct GridwiseMoeGemmMX
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
@@ -1395,21 +1406,26 @@ struct GridwiseMoeGemmMX
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave / NXdlPack>{},
I1,
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1511,15 +1527,17 @@ struct GridwiseMoeGemmMX
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
0,
thread_offset_shuffled / scale_pack_size_b));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
@@ -1958,17 +1976,17 @@ struct GridwiseMoeGemmMX
problem.NPadded,
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
math::integer_divide_ceil(problem.K, ScaleBlockSize) /
ScalesPerXdlopsRunPerThread,
ScalesPerXdlopsRunPerThread),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize),
ScalesPerXdlopsRunPerThread,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockSize), 1));
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -2145,62 +2163,48 @@ struct GridwiseMoeGemmMX
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_shuffled =
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
mfma.selected_mfma.num_threads_per_blk;
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
auto a_thread_offset_m = waveId_m;
// get each thread's offset int the scale tensor
const index_t token_scale_pos = block_m_id * MPerBlock;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, MXdlPerWave> scale_gather_offsets;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) {
const index_t fused_token =
p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWave + a_thread_offset_m];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) =
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockSize);
});
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather<
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
decltype(a_scale_grid_desc_am_ak),
decltype(BlockwiseGemmPipe::a_scale_thread_desc),
Sequence<1, 1, 1>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
1, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true,
MXdlPerWave,
KRepeat>(
a_scale_grid_desc_am_ak, make_multi_index(0, thread_offset_k, 0), scale_gather_offsets);
Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
0,
thread_offset_shuffled / scale_pack_size_a));
// B scale load
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
auto b_thread_offset_n = waveId_n;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
0,
thread_offset_shuffled / scale_pack_size_b));
if constexpr(IsInputGemm)
{
@@ -2231,15 +2235,18 @@ struct GridwiseMoeGemmMX
BScaleDataType,
BScaleDataType,
decltype(b_scale_grid_desc_bn_ak),
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
Sequence<1, 1>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // SrcScalarPerVector
1,
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
0,
thread_offset_shuffled / scale_pack_size_b));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,