mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Blockscale Gemm Fix Multi-Arch Compilation (#4451)
## Motivation This PR updates CK_TILE blockscale GEMM-quant kernels and launch helpers to compile across multiple GPU architectures by introducing compile-time availability gating and a new attribute tag mechanism for kernel symbol/attribute specialization. ## Technical Details - Add an architecture-guarded `kIsAvailable` flag to the gfx950 pipeline and propagate availability handling into `QuantGemmKernel`. - Extend `make_kernel`/`kentry` to accept an `Attr` tag enabling per-kernel compile-time attributes (e.g., `no-packed-fp32-ops`) and unique symbols. - Update the blockscale GEMM quant example to pass kernel attributes and adjust gfx950 gating. ## Test Plan - CI - Local test: `cmake .. --preset dev -DGPU_TARGETS='gfx942;gfx950' -GNinja && ninja tile_example_gemm_quant` - Local test with ROCm/aiter#1954 ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -10,10 +10,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
set(EXE_NAME tile_example_gemm_quant)
|
||||
add_executable(${EXE_NAME}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
#if defined(CK_TILE_EIGHTWARP_SUP)
|
||||
#if defined(CK_USE_GFX950)
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigEightWarps<T>;
|
||||
template <typename T>
|
||||
|
||||
@@ -246,6 +246,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0;
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
@@ -284,13 +285,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu, k_attr_t>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
ave_time = ck_tile::launch_kernel( //
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu, k_attr_t>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -15,37 +15,57 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
#endif
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP)
|
||||
__attribute__((target("no-packed-fp32-ops")))
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
template <typename T, typename = void>
|
||||
inline constexpr bool kattr_no_packed_fp32_ops_v = false;
|
||||
template <typename T>
|
||||
inline constexpr bool
|
||||
kattr_no_packed_fp32_ops_v<T, std::void_t<decltype(T::kattr_no_packed_fp32_ops)>> =
|
||||
T::kattr_no_packed_fp32_ops;
|
||||
|
||||
template <bool no_packed_fp32_ops>
|
||||
struct kernel_attr
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
Kernel{}(args...);
|
||||
// The kernel function attribute "no-packed-fp32-ops": Disable the use of packed FP32
|
||||
// instructions so that they can be co-executed with matrix operations
|
||||
static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
#define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
#else
|
||||
(..., (ignore = args, 0));
|
||||
#define KENTRY_LAUNCH_BOUNDS
|
||||
#endif
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#define KENTRY_BODY Kernel{}(args...)
|
||||
#define KENTRY_ATTR_NO_PACKED_FP32_OPS __attribute__((target("no-packed-fp32-ops")))
|
||||
#else
|
||||
#define KENTRY_BODY (..., (ignore = args, 0))
|
||||
#define KENTRY_ATTR_NO_PACKED_FP32_OPS
|
||||
#endif
|
||||
|
||||
template <int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
KENTRY_LAUNCH_BOUNDS __global__ void kentry(Args... args)
|
||||
{
|
||||
KENTRY_BODY;
|
||||
}
|
||||
template <typename Attr, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
KENTRY_LAUNCH_BOUNDS __global__ //
|
||||
std::enable_if_t<!kattr_no_packed_fp32_ops_v<Attr>>
|
||||
kentry(Args... args)
|
||||
{
|
||||
KENTRY_BODY;
|
||||
}
|
||||
template <typename Attr, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
KENTRY_LAUNCH_BOUNDS KENTRY_ATTR_NO_PACKED_FP32_OPS __global__ //
|
||||
std::enable_if_t<kattr_no_packed_fp32_ops_v<Attr>>
|
||||
kentry(Args... args)
|
||||
{
|
||||
KENTRY_BODY;
|
||||
}
|
||||
|
||||
template <typename Arch, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
#endif
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP)
|
||||
__attribute__((target("no-packed-fp32-ops")))
|
||||
#endif
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
Kernel{}(args...);
|
||||
#else
|
||||
(..., (ignore = args, 0));
|
||||
#endif
|
||||
}
|
||||
#undef KENTRY_LAUNCH_BOUNDS
|
||||
#undef KENTRY_BODY
|
||||
#undef KENTRY_ATTR_NO_PACKED_FP32_OPS
|
||||
|
||||
//
|
||||
// return a anonymous functor(lambda) to be called later
|
||||
@@ -54,26 +74,22 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
//
|
||||
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
|
||||
//
|
||||
// Arch can be used to support linking multiple object files that have the same kernel compiled for
|
||||
// Attr can be used to support linking multiple object files that have the same kernel compiled for
|
||||
// different architectures. In this case each object file has to use a different tag (gfx9_t,
|
||||
// gfx12_t etc.), so the kernel will have different symbols for each architecture.
|
||||
//
|
||||
// gfx12_t etc.), so the kernel will have different symbols for each architecture. It can also be
|
||||
// used to pass some compile-time attributes to the kernel.
|
||||
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename Arch = void,
|
||||
typename Attr = void,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
CK_TILE_HOST auto
|
||||
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
const auto kernel = []() {
|
||||
if constexpr(std::is_void_v<Arch>)
|
||||
{
|
||||
if constexpr(std::is_void_v<Attr>)
|
||||
return kentry<MinBlockPerCu, KernelImpl, Args...>;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kentry<Arch, MinBlockPerCu, KernelImpl, Args...>;
|
||||
}
|
||||
return kentry<Attr, MinBlockPerCu, KernelImpl, Args...>;
|
||||
}();
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
|
||||
@@ -1354,7 +1354,7 @@ struct QuantGemmKernel
|
||||
{
|
||||
m = kargs.M;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
return GemmPipeline{}(
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
@@ -1364,7 +1364,7 @@ struct QuantGemmKernel
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
return GemmPipeline{}(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
@@ -1376,20 +1376,19 @@ struct QuantGemmKernel
|
||||
// m = kargs.M;
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr,
|
||||
m,
|
||||
n);
|
||||
return GemmPipeline{}(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr,
|
||||
m,
|
||||
n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1454,7 +1453,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const
|
||||
{
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
@@ -1478,6 +1477,20 @@ struct QuantGemmKernel
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
template <typename T, typename = void>
|
||||
static constexpr bool kIsAvailableV = true;
|
||||
template <typename T>
|
||||
static constexpr bool kIsAvailableV<T, std::void_t<decltype(T::kIsAvailable)>> =
|
||||
T::kIsAvailable;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const
|
||||
{
|
||||
if constexpr(!kIsAvailableV<GemmPipeline>)
|
||||
ignore = kargs;
|
||||
else
|
||||
Run_(kargs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -26,6 +26,11 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvailable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvailable = false;
|
||||
#endif
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
|
||||
Reference in New Issue
Block a user