mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Introduce MX GEMM for FP8 data type (#2000)
This commit is contained in:
committed by
GitHub
parent
c027637a8f
commit
6660dc6b8e
@@ -793,7 +793,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
- static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
|
||||
+ static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -817,7 +817,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
- static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
|
||||
+ static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -841,7 +841,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
- static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
|
||||
+ static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -870,7 +870,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
- static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
|
||||
+ static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user