MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Updated CHANGELOG

* Remove unused static methods
This commit is contained in:
Andriy Roshchenko
2025-04-15 17:17:07 -06:00
committed by GitHub
parent d55c9cb313
commit 7106976a72
19 changed files with 1007 additions and 608 deletions

View File

@@ -211,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
* tensor.
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
* tensor.
* @tparam SrcScalarStrideInVector Not used.
* @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run
* or rolled back one step in MoveSrcSliceWindow
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for