Enable splitk for mxfp4; clang format;

This commit is contained in:
aska-0096
2025-06-02 12:23:01 +00:00
parent 5696e3c9f5
commit dd24786f78
12 changed files with 191 additions and 2486 deletions

View File

@@ -108,7 +108,6 @@ bool parse_cmd_args(int argc,
return true;
}
#if 1
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
@@ -146,8 +145,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
// 2-k)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
@@ -186,7 +186,6 @@ void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K
}
}
}
#endif
template <typename DeviceOpInstance,
typename ADataType,
@@ -346,7 +345,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
}
}
#if 1
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
@@ -358,48 +356,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
}
#endif
// printf("a_scale:\n");
// for(ck::index_t i = 0; i < M; i++)
// {
// for(ck::index_t j = 0; j < K / ScaleBlockSize; j++)
// {
// // a_m_k_scale(i, j) =
// // ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j / 4) % 4)));
// // a_m_k_scale(i, j) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// // a_shuffled_scale(i, j) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k_scale(i, j)));
// }
// printf("\n");
// }
// printf("b_scale:\n");
// for(ck::index_t i = 0; i < N; i++)
// {
// for(ck::index_t j = 0; j < K / ScaleBlockSize; j++)
// {
// // // b_k_n_scale(j, i) =
// // // ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j / 4) % 4)));
// // b_k_n_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// // b_shuffled_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n_scale(j, i)));
// }
// printf("\n");
// }
// printf("a_shuffled_scale:\n");
// for(ck::index_t i = 0; i < M * K / ScaleBlockSize; i++)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(a_shuffled_scale.mData.data()[i])));
// if(i % 64 == 63)
// printf("\n");
// }
// printf("b_shuffled_scale:\n");
// for(ck::index_t i = 0; i < N * K / ScaleBlockSize; i++)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(b_shuffled_scale.mData.data()[i])));
// if(i % 64 == 63)
// printf("\n");
// }
if(config.verbosity > 0)
std::cout << "Device memory allocation..." << std::endl;
@@ -524,9 +480,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// << std::endl;
// }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!");
res_verified =
res_verified &&
ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1);
if(config.verbosity > 0 && res_verified)
std::cout << "Verification Successful!" << std::endl;

View File

@@ -484,13 +484,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});

View File

@@ -465,12 +465,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});

View File

@@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
num_mfma_perstage / buffer_load_perstage_more
? num_mfma_perstage / buffer_load_perstage_more
: 1;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
num_mfma_perstage / buffer_load_perstage_less
? num_mfma_perstage / buffer_load_perstage_less
: 1;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
// Scale load, 1B
if constexpr (i.value==0){
if constexpr(i.value == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Scale load, 1A
@@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Scale load, 1A
if constexpr(imfma == 0){
if constexpr(imfma == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -426,7 +432,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
@@ -483,12 +489,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(I0));
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(I0));
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
@@ -496,7 +502,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
c_scale_thread_buf_up(m0) =
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
});
// Local prefill A1
@@ -532,10 +539,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(I0));
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(I0));
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
@@ -574,7 +581,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
});
});
__builtin_amdgcn_sched_barrier(0);
// main body
@@ -609,13 +615,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
@@ -627,13 +633,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
b_scale_thread_copy_step);
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(local_read_buf));
b_scale_grid_buf_up,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs_up(local_read_buf));
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step);
b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
@@ -676,8 +682,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
@@ -711,7 +717,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
c_scale_thread_vec_up
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
@@ -800,8 +807,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs_up[mfma_reg_buf][I0];
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
b_scale_thread_bufs[mfma_reg_buf][I0];
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
b_scale_thread_bufs_up[mfma_reg_buf][I0];
});
HotLoopScheduler();
@@ -824,13 +833,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[m0];
@@ -970,7 +979,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
c_scale_thread_buf_up(m0) =
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {

View File

@@ -43,56 +43,56 @@ constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
#if 0

View File

@@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
num_mfma_perstage / buffer_load_perstage_more
? num_mfma_perstage / buffer_load_perstage_more
: 1;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
num_mfma_perstage / buffer_load_perstage_less
? num_mfma_perstage / buffer_load_perstage_less
: 1;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
// Scale load, 1B
if constexpr (i.value==0){
if constexpr(i.value == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Scale load, 1A
@@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Scale load, 1A
if constexpr(imfma == 0){
if constexpr(imfma == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
@@ -420,7 +426,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
@@ -586,13 +592,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
@@ -744,7 +750,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
b_scale_thread_bufs[mfma_reg_buf][I0];
});
HotLoopScheduler();
@@ -768,7 +775,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
b_thread_bufs(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[m0];

View File

@@ -449,84 +449,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
// Global prefetch 2
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
#if defined(__gfx950__) && 0
printf("Tid: %02d, NRepeat0, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
"%02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat0, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
"%02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat1, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
"%02x %02x %02x %02x %02x %02x\n"
"Tid: %02d, NRepeat1, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
"%02x %02x %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 7>{}))),
get_thread_local_1d_id(),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 7>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 0>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 1>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 2>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 3>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 4>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 5>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 6>{}))),
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 7>{}))));
#endif
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);

View File

@@ -201,7 +201,7 @@ struct DeviceMoeGemmBlockScale
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{

View File

@@ -789,26 +789,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
}
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
}
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl;
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_scale_k_split_offset =
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
}
b_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl;
if(k_id < (karg.KBatch - 1))
{
@@ -1850,17 +1836,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),
@@ -2362,17 +2358,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),

View File

@@ -215,6 +215,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using mx_scale_t = e8m0_bexp_t;
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
@@ -806,7 +814,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
{
if constexpr(!PermuteB)
{
b_k_split_offset = k_id * karg.KRead;
b_k_split_offset = k_id * karg.KRead * NPerXdl;
}
else
{
@@ -816,26 +824,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
}
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
}
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
MPerXdl / scale_pack_size_a;
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_scale_k_split_offset =
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::MFMA, BLayout>)
{
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
}
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
NPerXdl / scale_pack_size_b;
if(k_id < (karg.KBatch - 1))
{
@@ -1289,14 +1283,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
using mx_scale_t = e8m0_bexp_t;
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
template <typename AGridDesc_AK0_M_K1,
typename AScaleGridDesc_AM_AK,
typename BGridDesc_BK0_N_K1,
@@ -2274,17 +2260,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
// We pad the M unconditionally for Scale
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),

View File

@@ -362,99 +362,6 @@ struct ThreadwiseTensorSliceTransfer_v2
}
}
// Debug helper: copies one slice from `src_buf` into `dst_buf` and, for every
// vector access, prints the calling thread's id together with the source
// coordinate offset it is about to read.
//
// Parameters:
//   src_desc  - source tensor descriptor (used for validity checks and for
//               stepping `src_coord_`).
//   src_buf   - source buffer; read one vector per access.
//   (DstDesc) - destination descriptor; only its type is used, it must be
//               known at compile time.
//   (DstSliceOriginIdx) - destination slice origin; only its type is used, it
//               must be known at compile time.
//   dst_buf   - destination buffer; written element by element through
//               compile-time offsets.
//
// NOTE: the printf label ("Ascale read gmem") is hard-coded, so the output is
// labelled that way regardless of which tensor is actually being read.
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void RunPrint(const SrcDesc& src_desc,
                         const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstSliceOriginIdx&,
                         DstBuffer& dst_buf)
{
    static_assert(DstDesc::IsKnownAtCompileTime(),
                  "wrong! DstDesc need to known at compile-time");

    static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                  "wrong! DstSliceOrigin need to known at compile-time");

    // The message is passed as the second static_assert argument (it used to be
    // and-ed into the condition, which left the assertion without a message).
    static_assert(
        is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
        "wrong! inconsistent type");

    // DstDesc and dst_slice_origin_idx are known at compile-time
    constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
    constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};

    // scalar per access on each dim
    // TODO: don't use lambda_scalar_per_access
    constexpr auto src_scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

    constexpr auto src_scalar_step_in_vector =
        generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

    using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                DimAccessOrder,
                                                remove_cv_t<decltype(src_scalar_per_access)>>;

    // loop over tensor and copy
    constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

    static_for<0, num_access, 1>{}([&](auto idx_1d) {
        typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;

        using src_vector_t =
            typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;

        constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

        const bool is_src_valid =
            coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

        // Debug trace: one line per thread per access with the source offset.
        printf("Tid: %03d, Ascale read gmem src_data_coord.GetOffset() = %d\n",
               get_thread_local_1d_id(),
               src_coord_.GetOffset());

        // copy data from src_buf into src_vector
        src_vector.template AsType<src_vector_t>()(Number<0>{}) =
            src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
                                               is_src_valid);

        // copy data from src_vector into dst_buf
        static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
            constexpr index_t dst_offset =
                dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                         i * src_scalar_step_in_vector);

            if constexpr(InvalidElementAsNaN)
            {
                // Invalid source elements are replaced by a quiet NaN.
                dst_buf(Number<dst_offset>{}) =
                    is_src_valid
                        ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
                        : NumericLimits<DstData>::QuietNaN();
            }
            else
            {
                dst_buf(Number<dst_offset>{}) =
                    type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
            }
        });

        // Advance the source coordinate to the next access, except after the last one.
        if constexpr(idx_1d.value != num_access - 1)
        {
            constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

            move_tensor_coordinate(
                src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
        }
    });

    // move src coordinate back to slice origin (or not)
    if constexpr(SrcResetCoordinateAfterRun)
    {
        const auto src_reset_step =
            make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

        move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
    }
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_scalar_per_access = generate_sequence(