Lwpck 3550: Implement and test fixed precision fp8 x bf8 (#2963)

* HasHotLoop is a constexpr

* Remove an unused function

* Remove some unused include statements

* Add implementation and tests for fp8 x bf8 weight preshuffle GEMM

* Add implementation and tests for fp8 x bf8 in CK Tile basic and universal GEMMs

* Remove two barrier calls that HotLoopScheduler already calls

* No need to suppress a variable that hasn't been declared

* Replace six arg_parser arguments with constexpr literals

* Simplify run_gemm_test_prec_type

* The strides don't need to be passed via arg_parser as we use their default values

* The layouts don't need to be passed as arguments twice

* Pass M N and K as regular arguments, not using the argument parser

* We can now remove the argument parser

* Add a common file for precision types to be used in testing

* Convert basic and universal GEMM tests to use gtest

* Make GemmConfig a test parameter, and form test cases as the cartesian product GemmConfigs x PrecTypes

* Add GemmConfigComputeV4 to the GEMM configs to run the universal tests on

* Added a changelog entry

* Add missing copyright statements

* ifndef-define-endif is not needed with pragma once

* Fix a comment

* Add F8 x BF8 tests for CompV4 in test_gemm_pipeline_kernel_types.hpp

* Disable the unreliable test MoeSortingCase4

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
SamiAario-AMD
2025-10-30 14:36:10 +02:00
committed by GitHub
parent 9ee9f4d2a3
commit 254bce9346
36 changed files with 411 additions and 491 deletions

View File

@@ -3,14 +3,10 @@
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -25,8 +21,6 @@ struct BaseGemmPipelineAgBgCrCompV3
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;

View File

@@ -484,7 +484,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
move_tile_window(b_tile_windows, b_dram_tile_window_step);
if(HasHotLoop)
if constexpr(HasHotLoop)
{
// minus 2 because we have ping-pong double buffer.
index_t iCounter = amd_wave_read_first_lane(num_loop - 2);
@@ -529,7 +529,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// pong
{
@@ -572,7 +571,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
// gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
iCounter -= 2;
} while(iCounter > 1);
@@ -631,8 +629,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
static_for<0, 8, 1>{}([&](auto i) {
ignore = i;
static_for<0, 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
});