diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 4e5c4c56ef..fb4692f2eb 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -10,10 +10,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1") list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") -if(GPU_TARGETS MATCHES "gfx95") - list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP) -endif() - if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 566f174d89..4ece442158 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -3,7 +3,7 @@ #include "run_gemm_quant_example.inc" -#if defined(CK_TILE_EIGHTWARP_SUP) +#if defined(CK_USE_GFX950) template using GemmConfig = GemmConfigEightWarps; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index ea5683601a..dd7e1abb02 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -246,6 +246,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str << std::endl; } float ave_time = 0; + using k_attr_t = ck_tile::kernel_attr; if(s.flush_cache_) { std::cout << "Flushing cache..." << std::endl; @@ -284,13 +285,15 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ave_time = ck_tile::launch_kernel_time_mask( s, run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( + ave_time = ck_tile::launch_kernel( // s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } return ave_time; diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 881a26b259..9b3fb88c16 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -15,37 +15,57 @@ namespace ck_tile { -template -#if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) -#endif -#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP) - __attribute__((target("no-packed-fp32-ops"))) -#endif - __global__ void kentry(Args... args) +template +inline constexpr bool kattr_no_packed_fp32_ops_v = false; +template +inline constexpr bool + kattr_no_packed_fp32_ops_v> = + T::kattr_no_packed_fp32_ops; + +template +struct kernel_attr { -#if defined(__HIP_DEVICE_COMPILE__) - Kernel{}(args...); + // The kernel function attribute "no-packed-fp32-ops": Disable the use of packed FP32 + // instructions so that they can be co-executed with matrix operations + static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops; +}; + +#if CK_TILE_USE_LAUNCH_BOUNDS +#define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #else - (..., (ignore = args, 0)); +#define KENTRY_LAUNCH_BOUNDS #endif +#if defined(__HIP_DEVICE_COMPILE__) +#define KENTRY_BODY Kernel{}(args...) +#define KENTRY_ATTR_NO_PACKED_FP32_OPS __attribute__((target("no-packed-fp32-ops"))) +#else +#define KENTRY_BODY (..., (ignore = args, 0)) +#define KENTRY_ATTR_NO_PACKED_FP32_OPS +#endif + +template +KENTRY_LAUNCH_BOUNDS __global__ void kentry(Args... args) +{ + KENTRY_BODY; +} +template +KENTRY_LAUNCH_BOUNDS __global__ // + std::enable_if_t> + kentry(Args... args) +{ + KENTRY_BODY; +} +template +KENTRY_LAUNCH_BOUNDS KENTRY_ATTR_NO_PACKED_FP32_OPS __global__ // + std::enable_if_t> + kentry(Args... args) +{ + KENTRY_BODY; } -template -#if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) -#endif -#if defined(__HIP_DEVICE_COMPILE__) && defined(CK_TILE_EIGHTWARP_SUP) - __attribute__((target("no-packed-fp32-ops"))) -#endif - __global__ void kentry(Args... args) -{ -#if defined(__HIP_DEVICE_COMPILE__) - Kernel{}(args...); -#else - (..., (ignore = args, 0)); -#endif -} +#undef KENTRY_LAUNCH_BOUNDS +#undef KENTRY_BODY +#undef KENTRY_ATTR_NO_PACKED_FP32_OPS // // return a anonymous functor(lambda) to be called later @@ -54,26 +74,22 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -// Arch can be used to support linking multiple object files that have the same kernel compiled for +// Attr can be used to support linking multiple object files that have the same kernel compiled for // different architectures. In this case each object file has to use a different tag (gfx9_t, -// gfx12_t etc.), so the kernel will have different symbols for each architecture. -// +// gfx12_t etc.), so the kernel will have different symbols for each architecture. It can also be +// used to pass some compile-time attributes to the kernel. template CK_TILE_HOST auto make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { const auto kernel = []() { - if constexpr(std::is_void_v) - { + if constexpr(std::is_void_v) return kentry; - } else - { - return kentry; - } + return kentry; }(); return [=](const stream_config& s) { kernel<<>>(args...); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 63682040ad..05e8aa62a9 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1354,7 +1354,7 @@ struct QuantGemmKernel { m = kargs.M; } - return GemmPipeline{}.template operator()( + return GemmPipeline{}( a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m); } else if constexpr(kQuantType == QuantType::BQuantGrouped) @@ -1364,7 +1364,7 @@ struct QuantGemmKernel { n = kargs.N; } - return GemmPipeline{}.template operator()( + return GemmPipeline{}( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n); } else if constexpr(kQuantType == QuantType::ABQuantGrouped) @@ -1376,20 +1376,19 @@ struct QuantGemmKernel // m = kargs.M; n = kargs.N; } - return GemmPipeline{}.template operator()(a_block_window, - b_block_window, - aq_block_window, - bq_block_window, - num_loop, - smem_ptr, - m, - n); + return GemmPipeline{}(a_block_window, + b_block_window, + aq_block_window, + bq_block_window, + num_loop, + smem_ptr, + m, + n); } else if constexpr(kQuantType == QuantType::RowColQuant || kQuantType == QuantType::TensorQuant) { - return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr); + return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr); } }(); @@ -1454,7 +1453,7 @@ struct QuantGemmKernel } } - CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const + CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const { const auto blockId = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); @@ -1478,6 +1477,20 @@ struct QuantGemmKernel RunGemm( a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } + + template + static constexpr bool kIsAvailableV = true; + template + static constexpr bool kIsAvailableV> = + T::kIsAvailable; + + CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const + { + if constexpr(!kIsAvailableV) + ignore = kargs; + else + Run_(kargs); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp index c036115de2..53ab399f98 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_async.hpp @@ -26,6 +26,11 @@ struct ABQuantGemmPipelineAgBgCrAsync : public BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmABQuantPipelineAgBgCrImplBase; +#if defined(__gfx950__) + static constexpr bool kIsAvailable = true; +#else + static constexpr bool kIsAvailable = false; +#endif using ADataType = remove_cvref_t; using AQDataType = remove_cvref_t;