mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4451 (commit 091bf0f)
[CK_TILE] Blockscale Gemm Fix Multi-Arch Compilation ## Motivation This PR updates CK_TILE blockscale GEMM-quant kernels and launch helpers to compile across multiple GPU architectures by introducing compile-time availability gating and a new attribute tag mechanism for kernel symbol/attribute specialization. ## Technical Details - Add an architecture-guarded `kIsAvailable` flag to the gfx950 pipeline and propagate availability handling into `QuantGemmKernel`. - Extend `make_kernel`/`kentry` to accept an `Attr` tag enabling per-kernel compile-time attributes (e.g., `no-packed-fp32-ops`) and unique symbols. - Update the blockscale GEMM quant example to pass kernel attributes and adjust gfx950 gating. ## Test Plan - CI - Local test: `cmake .. --preset dev -DGPU_TARGETS='gfx942;gfx950' -GNinja && ninja tile_example_gemm_quant` - Local test with ROCm/aiter#1954 ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
6a6cd05dbb
commit
d5acfd8d52
@@ -1354,7 +1354,7 @@ struct QuantGemmKernel
|
||||
{
|
||||
m = kargs.M;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
return GemmPipeline{}(
|
||||
a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
@@ -1364,7 +1364,7 @@ struct QuantGemmKernel
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
return GemmPipeline{}(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
@@ -1376,20 +1376,19 @@ struct QuantGemmKernel
|
||||
// m = kargs.M;
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr,
|
||||
m,
|
||||
n);
|
||||
return GemmPipeline{}(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr,
|
||||
m,
|
||||
n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1454,7 +1453,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const
|
||||
{
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
@@ -1478,6 +1477,20 @@ struct QuantGemmKernel
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
template <typename T, typename = void>
|
||||
static constexpr bool kIsAvailableV = true;
|
||||
template <typename T>
|
||||
static constexpr bool kIsAvailableV<T, std::void_t<decltype(T::kIsAvailable)>> =
|
||||
T::kIsAvailable;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const
|
||||
{
|
||||
if constexpr(!kIsAvailableV<GemmPipeline>)
|
||||
ignore = kargs;
|
||||
else
|
||||
Run_(kargs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user