Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-24 23:05:54 +00:00)
MX FP GEMM - Test MX FP8 MFMA Instructions (#1902)
* Refactored `load_A_row_major` to follow scale mapping
* Refactored `load_A_col_major` to follow scale mapping
* Refactored `load_B_col_major` to follow scale mapping
* Verified non-scaled test
* Verified scaled tests
* Used ReferenceMXGemm for verification
* Updated license headers
[ROCm/composable_kernel commit: ffa13455a2]
This commit is contained in:
Committed by: GitHub
Parent: d3f31b32d2
Commit: c3175995ba
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -780,7 +780,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: fix mfma...f8f6f4 instructions
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
|
||||
{
|
||||
@@ -847,9 +846,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
// clang-format on
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
__device__ void run(const FloatA& a,
|
||||
const int32_t scale_a,
|
||||
const FloatB& b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, scale_a, b, scale_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -871,9 +875,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
// clang-format on
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
__device__ void run(const FloatA& a,
|
||||
const int32_t scale_a,
|
||||
const FloatB& b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
|
||||
a, scale_a, b, scale_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -533,9 +533,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // blgp
|
||||
0, // { OPSEL_HI[0], OPSEL[0] }?
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // { OPSEL_HI[1], OPSEL[1] }?
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
@@ -569,9 +569,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // blgp
|
||||
0, // { OPSEL_HI[0], OPSEL[0] }?
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // { OPSEL_HI[1], OPSEL[1] }?
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
|
||||
Reference in New Issue
Block a user