mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Enable splitk for mxfp4; clang format;
This commit is contained in:
@@ -108,7 +108,6 @@ bool parse_cmd_args(int argc,
|
||||
return true;
|
||||
}
|
||||
|
||||
#if 1
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
@@ -146,8 +145,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
|
||||
// 2-k)));
|
||||
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
@@ -186,7 +186,6 @@ void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename DeviceOpInstance,
|
||||
typename ADataType,
|
||||
@@ -346,7 +345,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
|
||||
a_shuffled_scale.mData.data(),
|
||||
Scale_Padded_M,
|
||||
@@ -358,48 +356,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
int NPerXdl = 16; // Fixed 16
|
||||
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
|
||||
}
|
||||
#endif
|
||||
// printf("a_scale:\n");
|
||||
// for(ck::index_t i = 0; i < M; i++)
|
||||
// {
|
||||
// for(ck::index_t j = 0; j < K / ScaleBlockSize; j++)
|
||||
// {
|
||||
// // a_m_k_scale(i, j) =
|
||||
// // ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j / 4) % 4)));
|
||||
// // a_m_k_scale(i, j) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
|
||||
// // a_shuffled_scale(i, j) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k_scale(i, j)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
// printf("b_scale:\n");
|
||||
// for(ck::index_t i = 0; i < N; i++)
|
||||
// {
|
||||
// for(ck::index_t j = 0; j < K / ScaleBlockSize; j++)
|
||||
// {
|
||||
// // // b_k_n_scale(j, i) =
|
||||
// // // ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j / 4) % 4)));
|
||||
// // b_k_n_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
|
||||
// // b_shuffled_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n_scale(j, i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// printf("a_shuffled_scale:\n");
|
||||
// for(ck::index_t i = 0; i < M * K / ScaleBlockSize; i++)
|
||||
// {
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(a_shuffled_scale.mData.data()[i])));
|
||||
// if(i % 64 == 63)
|
||||
// printf("\n");
|
||||
// }
|
||||
// printf("b_shuffled_scale:\n");
|
||||
// for(ck::index_t i = 0; i < N * K / ScaleBlockSize; i++)
|
||||
// {
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(b_shuffled_scale.mData.data()[i])));
|
||||
// if(i % 64 == 63)
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Device memory allocation..." << std::endl;
|
||||
@@ -524,9 +480,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// << std::endl;
|
||||
// }
|
||||
|
||||
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!");
|
||||
res_verified =
|
||||
res_verified &&
|
||||
ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1);
|
||||
|
||||
if(config.verbosity > 0 && res_verified)
|
||||
std::cout << "Verification Successful!" << std::endl;
|
||||
|
||||
@@ -484,13 +484,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -465,12 +465,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
|
||||
constexpr auto buffer_load_issue_point_b = 0;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
|
||||
num_mfma_perstage / buffer_load_perstage_more
|
||||
? num_mfma_perstage / buffer_load_perstage_more
|
||||
: 1;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
|
||||
num_mfma_perstage / buffer_load_perstage_less
|
||||
? num_mfma_perstage / buffer_load_perstage_less
|
||||
: 1;
|
||||
constexpr auto ds_write_issue_point = 0;
|
||||
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
|
||||
|
||||
// B global read
|
||||
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
|
||||
// Scale load, 1B
|
||||
if constexpr (i.value==0){
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
// Scale load, 1A
|
||||
@@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// Scale load, 1A
|
||||
if constexpr(imfma == 0){
|
||||
if constexpr(imfma == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
@@ -426,7 +432,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
@@ -483,12 +489,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(I0));
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(I0));
|
||||
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
@@ -496,7 +502,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
});
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
|
||||
c_scale_thread_buf_up(m0) =
|
||||
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
|
||||
});
|
||||
|
||||
// Local prefill A1
|
||||
@@ -532,10 +539,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(I0));
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(I0));
|
||||
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
@@ -574,7 +581,6 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
@@ -609,13 +615,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -627,13 +633,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
b_scale_thread_copy_step);
|
||||
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(local_read_buf));
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs_up(local_read_buf));
|
||||
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
@@ -676,8 +682,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
|
||||
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -711,7 +717,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec_up.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_scale_thread_vec_up
|
||||
.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
@@ -800,8 +807,10 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs_up[mfma_reg_buf][I0];
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
|
||||
b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
|
||||
b_scale_thread_bufs_up[mfma_reg_buf][I0];
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
@@ -824,13 +833,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
@@ -970,7 +979,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
c_scale_thread_buf_up(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
|
||||
c_scale_thread_buf_up(m0) =
|
||||
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0];
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
|
||||
@@ -43,56 +43,56 @@ constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
|
||||
@@ -254,16 +254,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
|
||||
constexpr auto buffer_load_issue_point_b = 0;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
|
||||
num_mfma_perstage / buffer_load_perstage_more
|
||||
? num_mfma_perstage / buffer_load_perstage_more
|
||||
: 1;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
|
||||
num_mfma_perstage / buffer_load_perstage_less
|
||||
? num_mfma_perstage / buffer_load_perstage_less
|
||||
: 1;
|
||||
constexpr auto ds_write_issue_point = 0;
|
||||
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
|
||||
|
||||
// B global read
|
||||
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
|
||||
// Scale load, 1B
|
||||
if constexpr (i.value==0){
|
||||
if constexpr(i.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
// Scale load, 1A
|
||||
@@ -330,7 +335,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// Scale load, 1A
|
||||
if constexpr(imfma == 0){
|
||||
if constexpr(imfma == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
@@ -420,7 +426,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
@@ -586,13 +592,13 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
}
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -744,7 +750,8 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
|
||||
b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
@@ -768,7 +775,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
b_thread_bufs(I1));
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
|
||||
@@ -449,84 +449,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlockGemmPipelineScheduler:
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1));
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
#if defined(__gfx950__) && 0
|
||||
printf("Tid: %02d, NRepeat0, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
|
||||
"%02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat0, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
|
||||
"%02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat1, B 00-31 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
|
||||
"%02x %02x %02x %02x %02x %02x\n"
|
||||
"Tid: %02d, NRepeat1, B 32-63 %02x %02x %02x %02x %02x %02x %02x %02x %02x %02x "
|
||||
"%02x %02x %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<8 + 7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<16 + 8 + 7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 8 + 7>{}))),
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 7>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 0>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 1>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 2>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 3>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 4>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 5>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 6>{}))),
|
||||
*reinterpret_cast<uint8_t*>(&(b_thread_bufs(I0)(Number<32 + 16 + 8 + 7>{}))));
|
||||
#endif
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -201,7 +201,7 @@ struct DeviceMoeGemmBlockScale
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
|
||||
@@ -789,26 +789,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
|
||||
}
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl;
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset =
|
||||
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
|
||||
}
|
||||
b_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl;
|
||||
|
||||
if(k_id < (karg.KBatch - 1))
|
||||
{
|
||||
@@ -1850,17 +1836,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
@@ -2362,17 +2358,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
|
||||
@@ -215,6 +215,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
@@ -806,7 +814,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead;
|
||||
b_k_split_offset = k_id * karg.KRead * NPerXdl;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -816,26 +824,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
}
|
||||
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_scale_k_split_offset =
|
||||
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
|
||||
}
|
||||
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack *
|
||||
MPerXdl / scale_pack_size_a;
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset =
|
||||
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::MFMA, BLayout>)
|
||||
{
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
|
||||
}
|
||||
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack *
|
||||
NPerXdl / scale_pack_size_b;
|
||||
|
||||
if(k_id < (karg.KBatch - 1))
|
||||
{
|
||||
@@ -1289,14 +1283,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename AScaleGridDesc_AM_AK,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
@@ -2274,17 +2260,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
// We pad the M unconditionaly for Scale
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -362,99 +362,6 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void RunPrint(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
// loop over tensor and copy
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
|
||||
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
printf("Tid: %03d, Ascale read gmem src_data_coord.GetOffset() = %d\n",
|
||||
get_thread_local_1d_id(),
|
||||
src_coord_.GetOffset());
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
|
||||
is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
is_src_valid
|
||||
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
|
||||
}
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
|
||||
Reference in New Issue
Block a user