Fix the Composable Kernel CI and versions incompatibility (#4640)

## Motivation

This PR has 4 patches:
1. Fix the CI error in grouped GEMM.
2. Fix the incompatibility with old Linux versions.
3. Fix potential errors in flatmm.
4. Address the previous review comments on the abquant eight-warp pipeline
solution.

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
Thomas Ning
2026-02-18 22:59:37 +08:00
committed by GitHub
parent 24eaadc4d2
commit e5fb690945
12 changed files with 67 additions and 65 deletions

View File

@@ -73,7 +73,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_STANDARD
#endif
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0

View File

@@ -137,7 +137,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
auto expo = std::floor(std::log2(std::abs(max_possible_num)));
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
{
@@ -158,7 +158,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
}
else
{
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 1.0;
}
double midway_error = std::max(compute_error, output_error);

View File

@@ -65,11 +65,7 @@ inline bool is_gfx12_supported()
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
}
inline bool is_gfx95_supported()
{
// Check if load transpose is supported.
return get_device_name() == "gfx950";
}
inline bool is_gfx95_supported() { return get_device_name() == "gfx950"; }
inline size_t get_num_cus()
{

View File

@@ -116,13 +116,13 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
#ifdef __gfx9__
static constexpr bool AsyncPipeline = (MWave * NWave == 8);
#ifdef __gfx95__
static constexpr bool EightWave = (MWave * NWave == 8);
#else
static constexpr bool AsyncPipeline = false;
static constexpr bool EightWave = false;
#endif
static constexpr index_t BlockedXDLN_PerWarp =
AsyncPipeline ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
@@ -447,7 +447,7 @@ struct CShuffleEpilogue
if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
{
if constexpr(AsyncPipeline)
if constexpr(EightWave)
{
return tile_distribution_encoding<
sequence<>,

View File

@@ -780,29 +780,31 @@ struct FlatmmKernel
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m)
{
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int GM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int GK = decltype(kargs.scale_m_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
static_assert(GM != -1,
"MakeScaleMWindow should only be instantiated when scale is enabled");
// per-tensor (GM==0) -> Mdim = 1, stride 0
const index_t m_dim = (GM == 0) ? 1 : (kargs.M / GM);
const index_t m_stride = (GM == 0) ? 0 : 1;
const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK);
const index_t k_stride = 0; // your original code keeps K stride 0
// Step 1: Create tensor view
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0
? 1
: (splitk_batch_offset.splitted_k / ScaleGranularityKA)),
make_tuple(scale_stride_m, 0),
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
make_tuple(m_dim, k_dim),
make_tuple(m_stride, k_stride),
number < (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
number<1>{});
// Step 2: Create tile window
// Window extents: if GM==0, we still just broadcast from [0,*]
return make_tile_window(scale_m_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number < ScaleGranularityKA == 0
? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
number < (GK == 0) ? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock > {}),
{block_idx_m, 0});
}
@@ -811,27 +813,29 @@ struct FlatmmKernel
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_n)
{
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
constexpr int GN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int GK = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
static_assert(GN != -1,
"MakeScaleNWindow should only be instantiated when scale is enabled");
// per-tensor (GN==0) -> Ndim = 1, stride 0
const index_t n_dim = (GN == 0) ? 1 : (kargs.N / GN);
const index_t n_stride = (GN == 0) ? 0 : 1;
const index_t k_dim = (GK == 0) ? 1 : (splitk_batch_offset.splitted_k / GK);
const index_t k_stride = 0;
// Step 1: Create tensor view
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
make_tuple(k_dim, n_dim),
make_tuple(k_stride, n_stride),
number < (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<1>{});
// Step 2: Create tile window
return make_tile_window(scale_n_view,
make_tuple(number < ScaleGranularityKB == 0
? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
make_tuple(number < (GK == 0) ? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock > {},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
}
@@ -854,8 +858,6 @@ struct FlatmmKernel
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
@@ -866,6 +868,8 @@ struct FlatmmKernel
// Run Epilogue Pipeline with k_batch dispatching
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);
if(kargs.k_batch == 1)
{
auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(

View File

@@ -76,7 +76,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK);
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Problem::Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(KPerBlock / KWarp, BQuantGroupSize::kK);
@@ -158,7 +159,8 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
static constexpr bool PreshuffleQuant = Traits::PreshuffleQuant;
static constexpr bool APreshuffleQuant = Traits::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = Traits::BPreshuffleQuant;
static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
@@ -364,7 +366,7 @@ struct ABQuantBlockUniversalGemmAsBsCrAsync : public BlockGemmQuantBase
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
auto pull_from_lane =