implement shuffled scale mxfp4gemm, blocker: opsel not effect

This commit is contained in:
aska-0096
2025-05-11 05:54:13 +00:00
parent 6c761bf9b8
commit 41ea1066ac
7 changed files with 137 additions and 51 deletions

View File

@@ -133,13 +133,15 @@ void preShuffleScaleBuffer(const ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk / KXdlPack; // i XdlKThread
int k2 = tempk % KXdlPack; // i KXdlPack
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
k1 * MNXdlPack * KXdlPack * XdlMNThread +
n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack +
n2;
dst[outputIndex] = src[n * K + k];
}
@@ -332,6 +334,27 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
std::cout << "NOTE: No input data initialization." << std::endl;
}
}
printf("a_scale:\n");
for (size_t i = 0; i < M; i++)
{
for (size_t j = 0; j < K / ScaleBlockSize; j++)
{
a_m_k_scale(i, j) = ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j/4)%4)));
printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k_scale(i, j)));
}
printf("\n");
}
printf("b_scale:\n");
for (size_t i = 0; i < N; i++)
{
for (size_t j = 0; j < K / ScaleBlockSize; j++)
{
b_k_n_scale(j, i) = ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j/4)%4)));
printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n_scale(j, i)));
}
printf("\n");
}
#if 1
preShuffleScaleBuffer(
a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize);
@@ -339,6 +362,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
#endif
printf("a_shuffled_scale:\n");
for (size_t i = 0; i < M*K / ScaleBlockSize; i++)
{
printf("%02x ", *reinterpret_cast<uint8_t*>(&(a_shuffled_scale.mData.data()[i])));
if(i % 64 == 63)
printf("\n");
}
printf("b_shuffled_scale:\n");
for (size_t i = 0; i < N*K / ScaleBlockSize; i++)
{
printf("%02x ", *reinterpret_cast<uint8_t*>(&(b_shuffled_scale.mData.data()[i])));
if(i % 64 == 63)
printf("\n");
}
if(config.verbosity > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
@@ -353,14 +392,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
// for (size_t i = 0; i < N; i++)
// {
// for (size_t j = 0; j < K / ScaleBlockSize; j++)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_shuffled_scale(j, i)));
// }
// printf("\n");
// }
if(config.verbosity > 0)
std::cout << "Done." << std::endl;

View File

@@ -367,19 +367,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
});
if(get_thread_local_1d_id() == 0)
#if 0
if(get_thread_local_1d_id())
{
printf("Scale A: %02x %02x %02x %02x\n",
printf("1stGMEM Tid: %03d, Scale A: %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<0>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<1>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<2>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<3>{}]));
}
#endif
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, 0, 0));
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
@@ -396,20 +397,21 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
if(get_thread_local_1d_id() == 0)
#if 0
if(get_thread_local_1d_id())
{
printf("Scale B: %02x %02x %02x %02x\n",
printf("1stGMEM Tid: %03d, Scale B: %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<0>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<1>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<2>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<3>{}]));
}
#endif
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, 0, 0));
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -483,7 +485,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, 0, 0));
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
@@ -504,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, 0, 0));
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
// TODO: consider scheduling the scale load
// -------------------------------------------------------------------------------------------
@@ -672,13 +674,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Even)
{
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
@@ -710,7 +705,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
#if 0
if(get_thread_local_1d_id())
{
printf("2stGMEM Tid: %03d, Scale A: %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<0>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<1>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<2>{}]),
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<3>{}]));
}
if(get_thread_local_1d_id())
{
printf("2stGMEM Tid: %03d, Scale B: %02x %02x %02x %02x\n",
get_thread_local_1d_id(),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<0>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<1>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<2>{}]),
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<3>{}]));
}
#endif
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -796,7 +811,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
});
});
#if 0
if(get_thread_local_1d_id())
{
printf("1stMFMA, Tid: %03d, floatC: %.0f %.0f %.0f %.0f\n",
get_thread_local_1d_id(),
c_thread_buf[Number<0>{}],
c_thread_buf[Number<1>{}],
c_thread_buf[Number<2>{}],
c_thread_buf[Number<3>{}]);
}
#endif
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
@@ -911,6 +936,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
});
});
#if 0
if(get_thread_local_1d_id())
{
printf("2stMFMA, Tid: %03d, floatC: %.0f %.0f %.0f %.0f\n",
get_thread_local_1d_id(),
c_thread_buf[Number<0>{}],
c_thread_buf[Number<1>{}],
c_thread_buf[Number<2>{}],
c_thread_buf[Number<3>{}]);
}
#endif
}
else if constexpr(TailNum == TailNumber::Odd)
{

View File

@@ -424,6 +424,15 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
Run(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
return ave_time;

View File

@@ -1279,7 +1279,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
}
}
}
#if 0
// check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
@@ -1290,7 +1290,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
return false;
}
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
@@ -1518,9 +1518,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// mfma.selected_mfma.num_threads_per_blk;
// A wave access continuous memory
auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize;
auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
auto a_thread_offset_m = waveId_m * MPerXdl * MXdlPack;
auto a_thread_offset_m = waveId_m;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
@@ -1537,7 +1537,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
make_multi_index(
block_m_id * MPerBlock + a_thread_offset_m, 0, thread_offset_shuffled));
auto b_thread_offset_n = waveId_n * NPerXdl * NXdlPack;
auto b_thread_offset_n = waveId_n;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,

View File

@@ -867,6 +867,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
}
};
__host__ __device__ inline constexpr int32_t get_exponent_value_ex(e8m0x4_bexp_t x)
{
return (
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<0>{}].data) ) |
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<1>{}].data) << 8) |
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<2>{}].data) << 16 ) |
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<3>{}].data) << 24));
}
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
{
@@ -899,20 +908,18 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
if(get_thread_local_1d_id() == 0)
{
printf("Before BitCast: Scale A: %08x, Scale B: %08x\n",
*reinterpret_cast<const uint32_t*>(&scale_a),
*reinterpret_cast<const uint32_t*>(&scale_b));
}
// if(get_thread_local_1d_id() == 0)
// {
// printf("Before BitCast: Scale A: %08x, Scale B: %08x\n",
// *reinterpret_cast<const uint32_t*>(&scale_a),
// *reinterpret_cast<const uint32_t*>(&scale_b));
// }
// static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this
// point."); static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at
// this point.");
// intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
// a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<int32_t>(scale_a), b, bit_cast<int32_t>(scale_b), reg_c);
a, get_exponent_value_ex(scale_a), b, get_exponent_value_ex(scale_b), reg_c);
}
};

View File

@@ -756,11 +756,13 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
const int32_t scale_b,
FloatC& reg_c)
{
// if(get_thread_local_1d_id()){
// printf("Scale A: %08x, Scale B: %08x\n",
// *reinterpret_cast<const uint8_t*>(&scale_a), *reinterpret_cast<const
// uint8_t*>(&scale_b));
// }
if(get_thread_local_1d_id()){
printf("Tid: %03d, Scale A: %08x, Scale B: %08x, OpSelA: %d, OpSelB: %d\n",
get_thread_local_1d_id(),
*reinterpret_cast<const uint32_t*>(&scale_a), *reinterpret_cast<const
uint32_t*>(&scale_b),
OpselA, OpselB);
}
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);

View File

@@ -2134,6 +2134,7 @@ using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;