Introduce MX GEMM for FP8 data type (#2000)

This commit is contained in:
Andriy Roshchenko
2025-03-24 15:41:07 -06:00
committed by GitHub
parent c027637a8f
commit 6660dc6b8e
11 changed files with 4129 additions and 135 deletions

View File

@@ -793,7 +793,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -817,7 +817,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -841,7 +841,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -870,7 +870,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on