Fix the Composable Kernel CI errors and version incompatibility (#4640)

## Motivation

This PR contains 4 patches:
1. Fix the CI error in grouped gemm.
2. Fix an incompatibility with older Linux versions.
3. Fix potential errors in flatmm.
4. Address the previous review comments on the abquant eight-warp pipeline solution.

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Authored by Thomas Ning on 2026-02-18 22:59:37 +08:00, committed by GitHub
parent 058be6c6e9
commit be25dd6775
12 changed files with 67 additions and 65 deletions

--- changed file ---

@@ -88,7 +88,7 @@ using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
static bool do_verify;
static int init_method;
static float alpha;
-static float beta;
+static float beta_;
static bool time_kernel;
int main(int argc, char* argv[])
@@ -150,7 +150,7 @@ int main(int argc, char* argv[])
};
alpha = 1.0f;
-beta = 0.0f;
+beta_ = 0.0f;
Tensor<InOutDataType> in_1(inLengths_1);
@@ -174,22 +174,22 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
in_1.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
-if(beta != 0.0f)
+if(beta_ != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
break;
case 2:
in_1.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
-if(beta != 0.0f)
+if(beta_ != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
break;
default:
in_1.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0}, num_thread);
-if(beta != 0.0f)
+if(beta_ != 0.0f)
out_ref.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0},
num_thread);
}
-if(beta != 0.0f)
+if(beta_ != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
@@ -200,7 +200,7 @@ int main(int argc, char* argv[])
in_1_dev.ToDevice(in_1.mData.data());
-if(beta != 0.0f)
+if(beta_ != 0.0f)
out_dev.ToDevice(out.mData.data());
InElementwiseOperation in_elementwise_op;
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
arrOutStrides,
reduceDims,
static_cast<double>(alpha),
-static_cast<double>(beta),
+static_cast<double>(beta_),
in_1.mData.data(),
nullptr,
out_ref.mData.data(),
@@ -298,7 +298,7 @@ int main(int argc, char* argv[])
arrOutStrides,
reduceDims_2,
static_cast<double>(alpha),
-static_cast<double>(beta),
+static_cast<double>(beta_),
in_2_dev.GetDeviceBuffer(),
nullptr,
out_dev.GetDeviceBuffer(),
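
The hunks above only rename the file-scope `beta` to `beta_`. The PR does not spell out the exact conflict, but renames like this are usually made because another visible entity shares the name; one plausible candidate is the C++17 special math function `std::beta` from `<cmath>`, which can make an unqualified `beta` ambiguous when a `using namespace std;` directive is in effect. The snippet below is purely illustrative of that failure mode, not the example's actual code.

```cpp
// Purely illustrative; the PR does not state the exact conflict behind the
// beta -> beta_ rename. With C++17 <cmath> visible, std::beta is a standard
// special function, so a using-directive plus a file-scope variable named
// `beta` can make the unqualified name ambiguous on some toolchains.
#include <cmath>

using namespace std;

// static float beta;  // `beta = 0.0f;` would then be ambiguous against std::beta
static float beta_;    // the trailing underscore sidesteps the clash

int main()
{
    beta_ = 0.0f;
    return static_cast<int>(beta_);
}
```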

--- changed file ---

@@ -8,6 +8,7 @@
#include <ostream>
#include <string>
#include <tuple>
+#include <numeric>
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"

--- changed file ---

@@ -8,6 +8,7 @@
#include <ostream>
#include <string>
#include <tuple>
+#include <numeric>
#include "flatmm_basic.hpp"

--- changed file ---

@@ -166,7 +166,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
}
ck_tile::index_t M =
-std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
+std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
// round up to the multiple of BlockM
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
});
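
For context on the grouped-flatmm fix above: `std::reduce` may apply its binary operation in any order and grouping, so it requires an associative, commutative operation whose two arguments are interchangeable. The lambda here treats one argument as the running total and the other as a group's `M` value to round up, so a strict left-to-right fold with `std::accumulate` (declared in `<numeric>`, hence the includes added to the flatmm examples) is the appropriate call, and it is also available on older toolchains. A minimal sketch of the same computation, with made-up `BlockM` and `Ms` values, follows.

```cpp
// Illustrative sketch only (BlockM and the Ms values are made up, not taken
// from the example): sum the per-group M values after rounding each one up
// to a multiple of BlockM. std::accumulate folds strictly left to right, so
// the asymmetric (accumulator, element) lambda is well defined.
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    static constexpr int BlockM = 128;     // assumed tile height
    std::vector<int> Ms = {100, 300, 257}; // hypothetical per-group M values

    const int padded_M =
        std::accumulate(Ms.begin(), Ms.end(), 0, [](int acc, int group_m) {
            // round up to the multiple of BlockM, as in the diff above
            return acc + (group_m + BlockM - 1) / BlockM * BlockM;
        });

    std::printf("padded M = %d\n", padded_M); // 128 + 384 + 384 = 896
}
```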

--- changed file ---

@@ -35,16 +35,19 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool IS_FP8BLOCKSCALE =
-QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 &&
+QuantMode == ck_tile::QuantType::ABQuantGrouped &&
(std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = GemmConfig::TransposeC;
constexpr bool eight_warps =
+IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 &&
+(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
#ifdef CK_GFX950_SUPPORT
-IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;
#else
false;
#endif
using ComputeDataType =
std::conditional_t<IS_FP8BLOCKSCALE, typename TypeConfig::ADataType, void>;
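
The reworked `eight_warps` predicate keeps the generic conditions (the fp8/bf8 block-scale check, `BQuantGroupSize::kN == 128`, and the product of eight warps) outside the preprocessor guard and leaves only the gfx950-specific `K_Warp_Tile == 128` clause inside it, so builds without `CK_GFX950_SUPPORT` simply collapse to `false`. Below is a self-contained sketch of that gating pattern; the names are illustrative stand-ins, not the repository's types.

```cpp
// Sketch of the feature-macro gating pattern (illustrative names only):
// generic conditions stay outside the guard, the architecture-specific
// clause is compiled only when the macro is defined, and every other
// target collapses the predicate to false.
#include <cstdio>

struct Config
{
    static constexpr int M_Warp = 2, N_Warp = 2, K_Warp = 2; // 8 warps total
    static constexpr int K_Warp_Tile = 128;
};

inline constexpr bool is_fp8_blockscale = true; // stand-in for the quant/type checks

inline constexpr bool eight_warps =
    is_fp8_blockscale &&
    (Config::M_Warp * Config::N_Warp * Config::K_Warp == 8) &&
#ifdef CK_GFX950_SUPPORT
    Config::K_Warp_Tile == 128; // gfx950 builds also require the 128-wide K warp tile
#else
    false;                      // other targets never take the eight-warp path
#endif

int main() { std::printf("eight_warps = %d\n", eight_warps ? 1 : 0); }
```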