MX GEMM - Parameterized Test Template (#2088)

* Tests for MX FP8 GEMM

* Improve documentation
This commit is contained in:
Andriy Roshchenko
2025-04-16 19:56:00 -06:00
committed by GitHub
parent da54464cce
commit 213b203a3c
12 changed files with 948 additions and 7 deletions

View File

@@ -22,6 +22,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
// clang-format off
/**
* \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types
*
@@ -31,8 +32,8 @@ namespace device {
* Assumptions:
* - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
* - Each scale applies to ScaleBlockSize elements in K direction
* - A scale matrix is row-major
* - B scale matrix is column-major
* - A scale matrix is a row-major matrix
* - B scale matrix is a column-major matrix
* - Scale data types must provide a get_exponent_value() specialization; the lowest 8 bits of the
* exponent are interpreted as a conventional biased Float32 exponent (E8M0)
*
@@ -72,10 +73,10 @@ namespace device {
* for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){
* for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){
* for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){
* // MFMA accumulation for multirate instructions
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){
* for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){
* // MFMA instruction
* // MFMA accumulation
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){
* // MFMA instruction
* for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){
* for(int m = mw; m < mw + MPerXDL; m++){
* for(int n = nw; n < nw + NPerXDL; n++){
* for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){
@@ -96,6 +97,7 @@ namespace device {
* \endcode
*
*/
// clang-format on
template <typename ALayout,
typename BLayout,
typename CLayout,
@@ -104,7 +106,7 @@ template <typename ALayout,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename GemmAccDataType,
typename GemmAccDataType, // TODO: always float
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,