mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
MX GEMM - New GEMM pipeline for MX data types (#2059)
* Allow selection of mfma_scale instructions
* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order
* Add constexpr and synchronize return type for `get_exponent_value`
* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`
* Add support for microscaling instructions in `XdlopsGemm`
* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper
* Remove software implementation of MX GEMM
* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction
* Update README
* Updated CHANGELOG
* Remove unused static methods
This commit is contained in:
committed by
GitHub
parent
d55c9cb313
commit
7106976a72
@@ -159,16 +159,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
|
||||
//> KPack is at least the k_per_blk of selected mfma
|
||||
//
|
||||
// Should be a multiple of k_per_blk.
|
||||
// TODO: Move this to blockwise pipeline base
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -1088,10 +1094,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
static_assert(KPerBlock % ScaleBlockSize == 0,
|
||||
"KPerBlock should be multiple of ScaleBlockSize");
|
||||
|
||||
static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat,
|
||||
"Single call to xdlops_gemm::Run should process exactly ScaleBlockSize "
|
||||
"elements in k dimension");
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
@@ -1476,61 +1478,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
|
||||
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
|
||||
|
||||
// NXdlPerWave == NRepeat
|
||||
// MXdlPerWave == MRepeat
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2
|
||||
// Initial thread mapping for:
|
||||
// BlockSize = 256
|
||||
// MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
|
||||
// For each [m0, n0] tile, there are 4 waves:
|
||||
// tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
|
||||
// tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
|
||||
// tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
|
||||
// tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
|
||||
|
||||
auto a_thread_offset_m =
|
||||
MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) +
|
||||
mfma.selected_mfma.group_size *
|
||||
((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl);
|
||||
auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl;
|
||||
// BlockSize = 128
|
||||
// MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
|
||||
// For each [m0, n0] tile, there are 2 waves:
|
||||
// tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
|
||||
// tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
|
||||
|
||||
auto b_thread_offset_n =
|
||||
get_thread_local_1d_id() % NPerXdl +
|
||||
(get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl;
|
||||
auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl;
|
||||
// TODO: Document initial thread mapping for more combinations of parameters
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak), // SrcDesc
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc
|
||||
Sequence<mfma.selected_mfma.group_size, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
0, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m,
|
||||
a_thread_offset_k / ScaleBlockSize));
|
||||
const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc),
|
||||
Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector
|
||||
1,
|
||||
false>(b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n,
|
||||
b_thread_offset_k / ScaleBlockSize));
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
|
||||
auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
|
||||
mfma.selected_mfma.num_threads_per_blk;
|
||||
|
||||
auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl;
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleDataType,
|
||||
AScaleDataType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(BlockwiseGemmPipe::a_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(
|
||||
a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k));
|
||||
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleDataType,
|
||||
BScaleDataType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(BlockwiseGemmPipe::b_scale_thread_desc_copy),
|
||||
Sequence<1, 1>, // SliceLengths
|
||||
Sequence<0, 1>, // DimAccessOrder
|
||||
1, // SrcVectorDim
|
||||
1, // SrcScalarPerVector
|
||||
1,
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
|
||||
Reference in New Issue
Block a user