Lwpck 3550: Implement and test fixed precision fp8 x bf8 (#2963)

* HasHotLoop is a constexpr

* Remove an unused function

* Remove some unused include statements

* Add implementation and tests for fp8 x bf8 weight preshuffle GEMM

* Add implementation and tests for fp8 x bf8 in CK Tile basic and universal GEMMs

* Remove two barrier calls that HotLoopScheduler already calls

* No need to suppress a variable that hasn't been declared

* Replace six arg_parser arguments with constexpr literals

* Simplify run_gemm_test_prec_type

* The strides don't need to be passed via arg_parser as we use their default values

* The layouts don't need to be passed as arguments twice

* Pass M, N, and K as regular arguments instead of going through the argument parser

* We can now remove the argument parser

* Add a common file for precision types to be used in testing

* Convert basic and universal GEMM tests to use gtest

* Make GemmConfig a test parameter, and form test cases as the cartesian product GemmConfigs x PrecTypes

* Add GemmConfigComputeV4 to the GEMM configs to run the universal tests on

* Add a changelog entry

* Add missing copyright statements

* ifndef-define-endif is not needed with pragma once

* Fix a comment

* Add F8 x BF8 tests for CompV4 in test_gemm_pipeline_kernel_types.hpp

* Disable the unreliable test MoeSortingCase4
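
The item above about forming test cases as the cartesian product GemmConfigs x PrecTypes can be illustrated with a small compile-time sketch. This is not the code from this commit: `TypeList`, `Concat`, `PairWith`, `CartesianProduct`, `Size`, `PrecType`, `GemmConfigBase` and the empty `F8`/`BF8` structs are hypothetical stand-ins, and only the name `GemmConfigComputeV4` is taken from the message. In a gtest setup, a list like `TestCases` would presumably be converted to the type list that typed tests consume; that step is omitted here.

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// Hypothetical stand-ins for the element types and GEMM configs.
struct F8 {};
struct BF8 {};
struct GemmConfigBase {};
struct GemmConfigComputeV4 {};

// Hypothetical precision pair: A operand type x B operand type.
template <typename A, typename B>
struct PrecType { using ADataType = A; using BDataType = B; };

template <typename... Ts>
struct TypeList {};

// Concatenate any number of TypeLists into one.
template <typename... Lists>
struct Concat;
template <>
struct Concat<> { using type = TypeList<>; };
template <typename... Ts>
struct Concat<TypeList<Ts...>> { using type = TypeList<Ts...>; };
template <typename... As, typename... Bs, typename... Rest>
struct Concat<TypeList<As...>, TypeList<Bs...>, Rest...>
{
    using type = typename Concat<TypeList<As..., Bs...>, Rest...>::type;
};

// Pair one config with every precision type.
template <typename Config, typename Precs>
struct PairWith;
template <typename Config, typename... Ps>
struct PairWith<Config, TypeList<Ps...>>
{
    using type = TypeList<std::tuple<Config, Ps>...>;
};

// Cartesian product: every config paired with every precision type,
// ordered config-major.
template <typename Configs, typename Precs>
struct CartesianProduct;
template <typename... Cs, typename Precs>
struct CartesianProduct<TypeList<Cs...>, Precs>
{
    using type = typename Concat<typename PairWith<Cs, Precs>::type...>::type;
};

template <typename List>
struct Size;
template <typename... Ts>
struct Size<TypeList<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)> {};

using GemmConfigs = TypeList<GemmConfigBase, GemmConfigComputeV4>;
using PrecTypes   = TypeList<PrecType<F8, F8>, PrecType<F8, BF8>, PrecType<BF8, BF8>>;
using TestCases   = CartesianProduct<GemmConfigs, PrecTypes>::type;

// 2 configs x 3 precision pairs give 6 test cases.
static_assert(Size<TestCases>::value == 6, "unexpected number of test cases");
static_assert(
    std::is_same<TestCases,
                 TypeList<std::tuple<GemmConfigBase, PrecType<F8, F8>>,
                          std::tuple<GemmConfigBase, PrecType<F8, BF8>>,
                          std::tuple<GemmConfigBase, PrecType<BF8, BF8>>,
                          std::tuple<GemmConfigComputeV4, PrecType<F8, F8>>,
                          std::tuple<GemmConfigComputeV4, PrecType<F8, BF8>>,
                          std::tuple<GemmConfigComputeV4, PrecType<BF8, BF8>>>>::value,
    "unexpected ordering of test cases");
```

All checks in this sketch run at compile time, so it has no `main`. With this arrangement, adding one config or one precision pair extends the test matrix without listing each combination by hand.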

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
SamiAario-AMD
2025-10-30 14:36:10 +02:00
committed by GitHub
parent 9ee9f4d2a3
commit 254bce9346
36 changed files with 411 additions and 491 deletions

@@ -263,6 +263,9 @@ using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_fp8_bf8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
@@ -277,6 +280,10 @@ using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
    2>>;
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
    2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
    WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;