mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
MX GEMM - Parameterized Test Template (#2088)
* Tests for MX FP8 GEMM * Improve documentation
This commit is contained in:
committed by
GitHub
parent
da54464cce
commit
213b203a3c
@@ -22,6 +22,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// clang-format off
|
||||
/**
|
||||
* \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types
|
||||
*
|
||||
@@ -31,8 +32,8 @@ namespace device {
|
||||
* Assumptions:
|
||||
* - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
|
||||
* - Each scale applies to ScaleBlockSize elements in K direction
|
||||
* - A scale matrix is row-major
|
||||
* - B scale matrix is column-major
|
||||
* - A scale matrix is row-major
|
||||
* - B scale matrix is column-major
|
||||
* - Scale data types must have a get_exponent_value() specialization, where the lowest 8 bits of the
|
||||
* exponent will be interpreted as a conventional biased Float32 exponent (E8M0)
|
||||
*
|
||||
@@ -72,10 +73,10 @@ namespace device {
|
||||
* for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){
|
||||
* for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){
|
||||
* for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){
|
||||
* // MFMA accumulation for multirate instructions
|
||||
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){
|
||||
* for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){
|
||||
* // MFMA instruction
|
||||
* // MFMA accumulation
|
||||
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){
|
||||
* // MFMA instruction
|
||||
* for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){
|
||||
* for(int m = mw; m < mw + MPerXDL; m++){
|
||||
* for(int n = nw; n < nw + NPerXDL; n++){
|
||||
* for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){
|
||||
@@ -96,6 +97,7 @@ namespace device {
|
||||
* \endcode
|
||||
*
|
||||
*/
|
||||
// clang-format on
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -104,7 +106,7 @@ template <typename ALayout,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename GemmAccDataType, // TODO: always float
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
|
||||
Reference in New Issue
Block a user