mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)
* Add gemm_mx_fp8_bf8 example with row-major B
* Add more overloads of MX MFMA instructions
* Add MK_KN (RRR) tests
* Add KM_NK (CCR) tests
* Add more problem sizes to Large tests
* Add test_gemm_mx to the list of regression tests
[ROCm/composable_kernel commit: 79b0bfeb41]
This commit is contained in:
committed by
GitHub
parent
426ed45111
commit
ac743ddf0b
@@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
|
||||
@@ -1129,6 +1129,12 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
@@ -1147,6 +1153,18 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
|
||||
@@ -532,7 +532,44 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
// XXX: Note on the scale_a and scale_b parameters:
|
||||
// If compiler detects that one or both scales are constant values, it will treat that
|
||||
// constant as F32 constant. I.e., if scale_a at some point was declared as
|
||||
// `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
|
||||
// assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.
|
||||
|
||||
// XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
|
||||
// when OPSEL is set otherwise.
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -576,7 +613,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -605,7 +642,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
@@ -617,6 +654,64 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const bf8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const f8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user