mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
implement shuffled scale mxfp4gemm, blocker: opsel not effect
This commit is contained in:
@@ -133,13 +133,15 @@ void preShuffleScaleBuffer(const ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk / KXdlPack; // i XdlKThread
|
||||
int k2 = tempk % KXdlPack; // i KXdlPack
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread +
|
||||
n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack +
|
||||
n2;
|
||||
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
}
|
||||
@@ -332,6 +334,27 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
std::cout << "NOTE: No input data initialization." << std::endl;
|
||||
}
|
||||
}
|
||||
printf("a_scale:\n");
|
||||
for (size_t i = 0; i < M; i++)
|
||||
{
|
||||
for (size_t j = 0; j < K / ScaleBlockSize; j++)
|
||||
{
|
||||
a_m_k_scale(i, j) = ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j/4)%4)));
|
||||
printf("%02x ", *reinterpret_cast<uint8_t*>(&a_m_k_scale(i, j)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("b_scale:\n");
|
||||
for (size_t i = 0; i < N; i++)
|
||||
{
|
||||
for (size_t j = 0; j < K / ScaleBlockSize; j++)
|
||||
{
|
||||
b_k_n_scale(j, i) = ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j/4)%4)));
|
||||
printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n_scale(j, i)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
#if 1
|
||||
preShuffleScaleBuffer(
|
||||
a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize);
|
||||
@@ -339,6 +362,22 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
|
||||
#endif
|
||||
|
||||
printf("a_shuffled_scale:\n");
|
||||
for (size_t i = 0; i < M*K / ScaleBlockSize; i++)
|
||||
{
|
||||
printf("%02x ", *reinterpret_cast<uint8_t*>(&(a_shuffled_scale.mData.data()[i])));
|
||||
|
||||
if(i % 64 == 63)
|
||||
printf("\n");
|
||||
}
|
||||
printf("b_shuffled_scale:\n");
|
||||
for (size_t i = 0; i < N*K / ScaleBlockSize; i++)
|
||||
{
|
||||
printf("%02x ", *reinterpret_cast<uint8_t*>(&(b_shuffled_scale.mData.data()[i])));
|
||||
if(i % 64 == 63)
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Device memory allocation..." << std::endl;
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
|
||||
@@ -353,14 +392,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
|
||||
// for (size_t i = 0; i < N; i++)
|
||||
// {
|
||||
// for (size_t j = 0; j < K / ScaleBlockSize; j++)
|
||||
// {
|
||||
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_shuffled_scale(j, i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Done." << std::endl;
|
||||
|
||||
@@ -367,19 +367,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
#if 0
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("Scale A: %02x %02x %02x %02x\n",
|
||||
printf("1stGMEM Tid: %03d, Scale A: %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<0>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<1>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<2>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I0)[Number<3>{}]));
|
||||
}
|
||||
|
||||
#endif
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, 0, 0));
|
||||
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
@@ -396,20 +397,21 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
#if 0
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("Scale B: %02x %02x %02x %02x\n",
|
||||
printf("1stGMEM Tid: %03d, Scale B: %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<0>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<1>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<2>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I0)[Number<3>{}]));
|
||||
}
|
||||
|
||||
#endif
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, 0, 0));
|
||||
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
@@ -483,7 +485,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, 0, 0));
|
||||
a_scale_grid_desc, make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
@@ -504,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, 0, 0));
|
||||
b_scale_grid_desc, make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// TODO: consider scheduling the scale load
|
||||
// -------------------------------------------------------------------------------------------
|
||||
@@ -672,13 +674,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
@@ -710,7 +705,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
#if 0
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("2stGMEM Tid: %03d, Scale A: %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<0>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<1>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<2>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&a_scale_thread_bufs(I1)[Number<3>{}]));
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("2stGMEM Tid: %03d, Scale B: %02x %02x %02x %02x\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<0>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<1>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<2>{}]),
|
||||
*reinterpret_cast<const uint8_t*>(&b_scale_thread_bufs(I1)[Number<3>{}]));
|
||||
}
|
||||
#endif
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
@@ -796,7 +811,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("1stMFMA, Tid: %03d, floatC: %.0f %.0f %.0f %.0f\n",
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_buf[Number<0>{}],
|
||||
c_thread_buf[Number<1>{}],
|
||||
c_thread_buf[Number<2>{}],
|
||||
c_thread_buf[Number<3>{}]);
|
||||
}
|
||||
#endif
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
@@ -911,6 +936,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
});
|
||||
});
|
||||
#if 0
|
||||
if(get_thread_local_1d_id())
|
||||
{
|
||||
printf("2stMFMA, Tid: %03d, floatC: %.0f %.0f %.0f %.0f\n",
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_buf[Number<0>{}],
|
||||
c_thread_buf[Number<1>{}],
|
||||
c_thread_buf[Number<2>{}],
|
||||
c_thread_buf[Number<3>{}]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
|
||||
@@ -424,6 +424,15 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -1279,7 +1279,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
@@ -1290,7 +1290,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
@@ -1518,9 +1518,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// mfma.selected_mfma.num_threads_per_blk;
|
||||
|
||||
// A wave access continuous memory
|
||||
auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize;
|
||||
auto thread_offset_shuffled = get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
auto a_thread_offset_m = waveId_m * MPerXdl * MXdlPack;
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
|
||||
@@ -1537,7 +1537,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
make_multi_index(
|
||||
block_m_id * MPerBlock + a_thread_offset_m, 0, thread_offset_shuffled));
|
||||
|
||||
auto b_thread_offset_n = waveId_n * NPerXdl * NXdlPack;
|
||||
auto b_thread_offset_n = waveId_n;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
|
||||
|
||||
@@ -867,6 +867,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ inline constexpr int32_t get_exponent_value_ex(e8m0x4_bexp_t x)
|
||||
{
|
||||
return (
|
||||
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<0>{}].data) ) |
|
||||
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<1>{}].data) << 8) |
|
||||
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<2>{}].data) << 16 ) |
|
||||
(static_cast<uint32_t>(x.template AsType<e8m0_bexp_t>()[Number<3>{}].data) << 24));
|
||||
}
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
{
|
||||
@@ -899,20 +908,18 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
printf("Before BitCast: Scale A: %08x, Scale B: %08x\n",
|
||||
*reinterpret_cast<const uint32_t*>(&scale_a),
|
||||
*reinterpret_cast<const uint32_t*>(&scale_b));
|
||||
}
|
||||
// if(get_thread_local_1d_id() == 0)
|
||||
// {
|
||||
// printf("Before BitCast: Scale A: %08x, Scale B: %08x\n",
|
||||
// *reinterpret_cast<const uint32_t*>(&scale_a),
|
||||
// *reinterpret_cast<const uint32_t*>(&scale_b));
|
||||
// }
|
||||
// static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this
|
||||
// point."); static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at
|
||||
// this point.");
|
||||
|
||||
// intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
// a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
a, bit_cast<int32_t>(scale_a), b, bit_cast<int32_t>(scale_b), reg_c);
|
||||
a, get_exponent_value_ex(scale_a), b, get_exponent_value_ex(scale_b), reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -756,11 +756,13 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
// if(get_thread_local_1d_id()){
|
||||
// printf("Scale A: %08x, Scale B: %08x\n",
|
||||
// *reinterpret_cast<const uint8_t*>(&scale_a), *reinterpret_cast<const
|
||||
// uint8_t*>(&scale_b));
|
||||
// }
|
||||
if(get_thread_local_1d_id()){
|
||||
printf("Tid: %03d, Scale A: %08x, Scale B: %08x, OpSelA: %d, OpSelB: %d\n",
|
||||
get_thread_local_1d_id(),
|
||||
*reinterpret_cast<const uint32_t*>(&scale_a), *reinterpret_cast<const
|
||||
uint32_t*>(&scale_b),
|
||||
OpselA, OpselB);
|
||||
}
|
||||
#if defined(__gfx950__)
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
@@ -2134,6 +2134,7 @@ using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
|
||||
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
|
||||
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
|
||||
|
||||
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;
|
||||
// pack int4
|
||||
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
|
||||
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
|
||||
|
||||
Reference in New Issue
Block a user