[CK_TILE] Blockscale Gemm Fix Multi-Arch Compilation (#4451)

## Motivation
This PR updates CK_TILE blockscale GEMM-quant kernels and launch helpers
to compile across multiple GPU architectures by introducing compile-time
availability gating and a new attribute tag mechanism for kernel
symbol/attribute specialization.

## Technical Details
- Add an architecture-guarded `kIsAvailable` flag to the gfx950 pipeline
and propagate availability handling into `QuantGemmKernel`.
- Extend `make_kernel`/`kentry` to accept an `Attr` tag enabling
per-kernel compile-time attributes (e.g., `no-packed-fp32-ops`) and
unique symbols.
- Update the blockscale GEMM quant example to pass kernel attributes and
adjust gfx950 gating.

## Test Plan
- CI
- Local test: `cmake .. --preset dev -DGPU_TARGETS='gfx942;gfx950'
-GNinja && ninja tile_example_gemm_quant`
- Local test with ROCm/aiter#1954
## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-02-10 20:41:09 +08:00
committed by GitHub
parent d61393a714
commit 1ac61a54c9
6 changed files with 90 additions and 57 deletions

View File

@@ -10,10 +10,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
if(GPU_TARGETS MATCHES "gfx95")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
set(EXE_NAME tile_example_gemm_quant)
add_executable(${EXE_NAME}

View File

@@ -3,7 +3,7 @@
#include "run_gemm_quant_example.inc"
#if defined(CK_TILE_EIGHTWARP_SUP)
#if defined(CK_USE_GFX950)
template <typename T>
using GemmConfig = GemmConfigEightWarps<T>;
template <typename T>

View File

@@ -246,6 +246,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
<< std::endl;
}
float ave_time = 0;
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
@@ -284,13 +285,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu, k_attr_t>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
ave_time = ck_tile::launch_kernel( //
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu, k_attr_t>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;

View File

@@ -15,37 +15,57 @@
namespace ck_tile {
template <int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP)
__attribute__((target("no-packed-fp32-ops")))
#endif
__global__ void kentry(Args... args)
template <typename T, typename = void>
inline constexpr bool kattr_no_packed_fp32_ops_v = false;
template <typename T>
inline constexpr bool
kattr_no_packed_fp32_ops_v<T, std::void_t<decltype(T::kattr_no_packed_fp32_ops)>> =
T::kattr_no_packed_fp32_ops;
template <bool no_packed_fp32_ops>
struct kernel_attr
{
#if defined(__HIP_DEVICE_COMPILE__)
Kernel{}(args...);
// The kernel function attribute "no-packed-fp32-ops": Disable the use of packed FP32
// instructions so that they can be co-executed with matrix operations
static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
};
#if CK_TILE_USE_LAUNCH_BOUNDS
#define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#else
(..., (ignore = args, 0));
#define KENTRY_LAUNCH_BOUNDS
#endif
#if defined(__HIP_DEVICE_COMPILE__)
#define KENTRY_BODY Kernel{}(args...)
#define KENTRY_ATTR_NO_PACKED_FP32_OPS __attribute__((target("no-packed-fp32-ops")))
#else
#define KENTRY_BODY (..., (ignore = args, 0))
#define KENTRY_ATTR_NO_PACKED_FP32_OPS
#endif
template <int MinBlockPerCu, typename Kernel, typename... Args>
KENTRY_LAUNCH_BOUNDS __global__ void kentry(Args... args)
{
KENTRY_BODY;
}
template <typename Attr, int MinBlockPerCu, typename Kernel, typename... Args>
KENTRY_LAUNCH_BOUNDS __global__ //
std::enable_if_t<!kattr_no_packed_fp32_ops_v<Attr>>
kentry(Args... args)
{
KENTRY_BODY;
}
template <typename Attr, int MinBlockPerCu, typename Kernel, typename... Args>
KENTRY_LAUNCH_BOUNDS KENTRY_ATTR_NO_PACKED_FP32_OPS __global__ //
std::enable_if_t<kattr_no_packed_fp32_ops_v<Attr>>
kentry(Args... args)
{
KENTRY_BODY;
}
template <typename Arch, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP)
__attribute__((target("no-packed-fp32-ops")))
#endif
__global__ void kentry(Args... args)
{
#if defined(__HIP_DEVICE_COMPILE__)
Kernel{}(args...);
#else
(..., (ignore = args, 0));
#endif
}
#undef KENTRY_LAUNCH_BOUNDS
#undef KENTRY_BODY
#undef KENTRY_ATTR_NO_PACKED_FP32_OPS
//
// return a anonymous functor(lambda) to be called later
@@ -54,26 +74,22 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
//
// Arch can be used to support linking multiple object files that have the same kernel compiled for
// Attr can be used to support linking multiple object files that have the same kernel compiled for
// different architectures. In this case each object file has to use a different tag (gfx9_t,
// gfx12_t etc.), so the kernel will have different symbols for each architecture.
//
// gfx12_t etc.), so the kernel will have different symbols for each architecture. It can also be
// used to pass some compile-time attributes to the kernel.
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename Arch = void,
typename Attr = void,
typename KernelImpl,
typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
const auto kernel = []() {
if constexpr(std::is_void_v<Arch>)
{
if constexpr(std::is_void_v<Attr>)
return kentry<MinBlockPerCu, KernelImpl, Args...>;
}
else
{
return kentry<Arch, MinBlockPerCu, KernelImpl, Args...>;
}
return kentry<Attr, MinBlockPerCu, KernelImpl, Args...>;
}();
return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);

View File

@@ -1354,7 +1354,7 @@ struct QuantGemmKernel
{
m = kargs.M;
}
return GemmPipeline{}.template operator()(
return GemmPipeline{}(
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
@@ -1364,7 +1364,7 @@ struct QuantGemmKernel
{
n = kargs.N;
}
return GemmPipeline{}.template operator()(
return GemmPipeline{}(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
@@ -1376,20 +1376,19 @@ struct QuantGemmKernel
// m = kargs.M;
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,
b_block_window,
aq_block_window,
bq_block_window,
num_loop,
smem_ptr,
m,
n);
return GemmPipeline{}(a_block_window,
b_block_window,
aq_block_window,
bq_block_window,
num_loop,
smem_ptr,
m,
n);
}
else if constexpr(kQuantType == QuantType::RowColQuant ||
kQuantType == QuantType::TensorQuant)
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr);
return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr);
}
}();
@@ -1454,7 +1453,7 @@ struct QuantGemmKernel
}
}
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const
{
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
@@ -1478,6 +1477,20 @@ struct QuantGemmKernel
RunGemm(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
template <typename T, typename = void>
static constexpr bool kIsAvailableV = true;
template <typename T>
static constexpr bool kIsAvailableV<T, std::void_t<decltype(T::kIsAvailable)>> =
T::kIsAvailable;
CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const
{
if constexpr(!kIsAvailableV<GemmPipeline>)
ignore = kargs;
else
Run_(kargs);
}
};
} // namespace ck_tile

View File

@@ -26,6 +26,11 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
#if defined(__gfx950__)
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvailable = false;
#endif
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;