From d1193e8637a4ac82217d0413e67ed52700c7f8fc Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 4 Dec 2025 18:29:14 -0800 Subject: [PATCH 01/24] fix hipblaslt build for different archs (#3358) --- Dockerfile.pytorch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 4533166c06..9628bf46fa 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx90a;gfx942;gfx950" -j 128 --skip_rocroller From 05292b3604e143e98ec2cb67edb2e3d2ad1d6ecb Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 5 Dec 2025 10:31:12 +0800 Subject: [PATCH 02/24] [CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153) * Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. * Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/ck_tile/01_fmha/CMakeLists.txt | 34 - .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 20 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 827 ++++++++++++------ .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 616 ------------- example/ck_tile/01_fmha/fmha_fwd.hpp | 94 ++ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 60 -- example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 73 -- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 179 ---- .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 - .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 - .../core/algorithm/coordinate_transform.hpp | 82 ++ include/ck_tile/core/tensor/tile_window.hpp | 9 +- .../ck_tile/ops/fmha/block/block_masking.hpp | 13 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 48 - .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 134 +-- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 32 +- .../pipeline/block_fmha_pipeline_enum.hpp | 1 + .../pipeline/block_fmha_pipeline_problem.hpp | 43 - .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 - include/ck_tile/remod.py | 2 +- 22 files changed, 890 insertions(+), 1449 deletions(-) delete mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp delete mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp delete mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 0c8102a70b..6e7d69281d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# add fmha_fwd_v3 example -set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") - -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp - ${FMHA_FWD_V3_INSTANCES} -) - -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -fgpu-flush-denormals-to-zero - -Wno-undefined-func-template - --save-temps -) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) - -check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) -if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 - ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - -DCK_TILE_DISABLE_PACKED_FP32=1 - ) -endif() - -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 333579ec8d..a3cfe2622a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -30,16 +30,24 @@ _MASK_MAP = { } -def get_mask_map(mask: str): - if mask == "generic": +def get_mask_map(mask_impl: str): + if mask_impl == "generic": return _MASK_MAP - elif mask == "simplified": + elif mask_impl == "simplified": return _MASK_SIMPLIFIED_MAP else: assert False return None +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + _MASK_CHECK_MAP = { "no": "t.mask_type == mask_enum::no_mask", "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", @@ -62,6 +70,10 @@ def get_mask_check_map(mask: str): return None +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", @@ -122,6 +134,7 @@ PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", } PIPELINE_ENUM_MAP = { @@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = { "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 17d4f6e1d7..c00bdcea3b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -8,14 +8,13 @@ import os from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, BIAS_CHECK_MAP, - get_mask_check_map, BOOL_MAP, PIPELINE_MAP, PIPELINE_ENUM_MAP, @@ -23,6 +22,8 @@ from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, QSCALE_CHECK_MAP, QSCALE_MAP, ) @@ -48,79 +49,79 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY = """ +FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_qscale}, - {F_occupancy}, - {F_skip}>; +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_qscale}, + {F_occupancy}, + {F_skip}>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; +using fmha_mask = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, + fmha_variant, + fmha_mask, {F_trload}, - fmha_trait_{F_idx}>; + fmha_traits>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; +using fmha_kernel = {F_kernel}; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}; + using k_ = fmha_kernel; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = {F_kargs_creator}(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -130,40 +131,47 @@ float fmha_fwd_(const ck_tile::stream_config& s, fm """ FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" -FMHA_FWD_API = """ +FMHA_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py #include #include -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ +#include "fmha_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { int device; auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device"); return false; - }} + } - hipDeviceProp_t props{{}}; + hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device properties"); return false; - }} + } num_cus = props.multiProcessorCount; return true; -}} +} -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{ +} +} // namespace +""" +FMHA_FWD_API_FUNC_TEMPLATE = """ +namespace {{ +float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -182,6 +190,28 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& {F_dispatch} return r; }} +}} // namespace +""" +FMHA_FWD_API_FOOTER_TEMPLATE = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const std::string device_name = ck_tile::get_device_name(); + + const bool is_swa = (traits.mask_type != mask_enum::no_mask) and + ((0 < args.window_size_left) or (0 < args.window_size_right)); + const bool can_dispatch_v3 = + (device_name.compare(0, 6, "gfx950") == 0) and + (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and + (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and + (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and + (args.hdim_v == 128); + if ({F_is_v3_enabled} and can_dispatch_v3) {{ + return fmha_fwd_v3(traits, args, config); + }} else {{ + return fmha_fwd_v2(traits, args, config); + }} +}} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -261,7 +291,7 @@ class FmhaFwdApiTrait: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.pipeline_tag in ["qr_async", "qr_async_trload", "qr_async_trload_v3"]: if self.spad == "t": return "true" # always support else: @@ -294,7 +324,7 @@ class FmhaFwdApiTrait: return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag == "qr_async_trload": + elif self.pipeline_tag in ["qr_async_trload", "qr_async_trload_v3"]: if self.skpad == "t": return "true" else: @@ -310,7 +340,7 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -327,7 +357,7 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -429,9 +459,8 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self): self.pool = OrderedDict() - self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: hdim = trait.hdim, trait.bn1 @@ -443,19 +472,60 @@ class FmhaFwdApiPool: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - @property - def api(self) -> str: + def get_num_traits( + self, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + + def render( + self, func_name, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> str: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any FmhaFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + per_arch = str() - for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if has_traits(item[1]) + ): per_dtypes = str() - for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): per_hdim_case = str() for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( - pool_by_dtype.items() + item for item in pool_by_dtype.items() if has_traits(item[1]) ): - max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) inners = str() - for i_trait, trait in enumerate(pool_by_hdim): + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter_fn(trait)] + ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), F_arch=arch, @@ -463,8 +533,8 @@ class FmhaFwdApiPool: F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], @@ -506,10 +576,9 @@ class FmhaFwdApiPool: F_arch=arch, F_dtype_case=indent(per_dtypes), ) - if not per_arch: - # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) @dataclass @@ -548,18 +617,32 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: F_arch: ArchTrait - F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type F_mode: str # value from MODE_MAP F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline - mask_impl: str - @property - def template(self) -> str: - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "ck_tile::FmhaFwdV3Kernel" + else: + return "ck_tile::FmhaFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "fmha_fwd_v3_create_kargs_and_grids" + else: + return "fmha_fwd_create_kargs_and_grids" + + def render(self) -> str: + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], @@ -594,10 +677,12 @@ class FmhaFwdKernel: F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), ) @property @@ -644,16 +729,179 @@ class FmhaFwdKernel: ) -class KernelComponentFactoryGfx9: +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + tile: FmhaFwdTileSize + pipeline: FmhaFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], +) -> bool: + return all(rule(problem_ctx, kernel_ctx) for rule in rules) + + +def create_kernel( + arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext +) -> FmhaFwdKernel: + return FmhaFwdKernel( + F_arch=arch, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> list[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128): + if ( + kernel_ctx.pipeline.F_bias != "no" + or kernel_ctx.pipeline.F_dropout == "t" + ): + False + return True + + def check_feature( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + # logits_soft_cap is only allowed if no bias + if not ( + ( + kernel_ctx.pipeline.F_logits == "t" + and kernel_ctx.pipeline.F_bias == "no" + ) + or kernel_ctx.pipeline.F_logits == "f" + ): + return False + return True + + return [check_mode, check_hdim, check_feature] + + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + def check_hdim_tile( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support + if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 != 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) + and kernel_ctx.tile.F_bm0 != 128 + ) + ): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + return False + return True + + rules.append(check_hdim_tile) + return rules + + +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + _AVAILABLE_PIPELINES = ( + CompatibilityRuleFactoryGfx9._AVAILABLE_PIPELINES + | frozenset({"qr_async_trload", "qr_async_trload_v3"}) + ) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactoryGfx9.get_rules() + + def check_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + ) + ): + return False + + # only qr_async_trload_v3 use km0=256 & 8-warps + is_v3_dedicated_tile = ( + kernel_ctx.tile.F_bm0 == 256 + and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8 + and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8 + ) # fmt: skip + is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" + return is_v3_dedicated_tile == is_v3_pipeline + + rules.extend([check_tile_pipeline]) + return rules + + +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) + _DT_FP32 = ("fp32",) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8 = ("fp8",) + _DT_FP8BF16 = ("fp8bf16",) + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return ( + cls._DT_FP32 + + cls._DT_FP16_BF16 + + cls._DT_FP8 + + cls._DT_FP8BF16 + + cls._DT_FP8FP32 + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp32"]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP32: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -666,7 +914,7 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: return { ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -682,30 +930,32 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ["fp32"]: + if dtype in cls._DT_FP32: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -718,7 +968,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -743,7 +993,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], @@ -755,21 +1005,33 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO - None - else: - assert False + pass return pipelines -class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): +class KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 +): arch = ArchTrait("gfx950") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + # add tile for qr_async_trload_v3 + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + return result + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( dtype, hdim, hdim_v, receipt, mask_impl ) - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -788,15 +1050,31 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + return pipelines -class KernelComponentFactoryGfx12: +class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp16", "bf16"]: + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -805,25 +1083,27 @@ class KernelComponentFactoryGfx12: (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { # bm0, bn0, bk0, bn1, bk1, (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = [] - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -835,23 +1115,21 @@ class KernelComponentFactoryGfx12: ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - else: - assert False return pipelines -class CustomFactory(KernelComponentFactoryGfx9): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in cls._DT_FP16_BF16: if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -874,150 +1152,162 @@ def get_factory(target: str): raise Exception(f"Unsupported device target {target}") +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + # Flash attention integration + if receipt in (2, 3): + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + return cond + + return Product(name="Flash attention integration", rule=fit) + # PyTorch integration + elif receipt == 4: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="PyTorch integration", rule=fit) + # Aiter(mha_fwd) integration + elif receipt == 100: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_fwd) integration", rule=fit) + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) + # aiter::mha_fwd C++ api integration + elif receipt == 600: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + elif receipt == 888: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp8bf16", "fp8fp32"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="receipt = 888", rule=fit) + # fp32 only, all variations + elif receipt == 800: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="fp32 only, all variations", rule=fit) + # fp32 only, minimal set of parameters + elif receipt == 801: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= problem_ctx.hdim in [48, 128] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_bias == "no" + cond &= kernel_ctx.pipeline.F_lse == "f" + cond &= kernel_ctx.pipeline.F_dropout == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_mask == "s_no" + return cond + + return Product(name="fp32 only, minimal set of parameters", rule=fit) + # Don't build fp32 by default + else: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return problem_ctx.dtype != "fp32" + + return Product(name="Default", rule=fit) + + def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool() factories = get_factories_for_targets(targets, get_factory) - for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue for tile, next_tile in zip(tiles, tiles[1:]): assert next_tile.F_bm0 >= tile.F_bm0, ( "Tiles must be ordered by increasing bm0" ) + for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue - if factory.arch.name.startswith("gfx9") and dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): - continue - k = FmhaFwdKernel( - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) + kernel_ctx = KernelContext( + tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + product = get_product(receipt) + + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): + continue + + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - cond &= pipeline.F_skip == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - cond &= mode == "batch" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - elif receipt == 888: - cond = dtype in ["fp8bf16", "fp8fp32"] - cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == "fp32" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == "fp32" - cond &= hdim in [48, 128] - cond &= mode == "batch" - cond &= pipeline.F_bias == "no" - cond &= pipeline.F_lse == "f" - cond &= pipeline.F_dropout == "f" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - cond &= pipeline.F_mask == "s_no" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == "fp32": - continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -1026,11 +1316,34 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + update_file(autogen_dir / kernel.filename, kernel.render()) -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) +def write_fwd_api( + api_pool: FmhaFwdApiPool, + autogen_dir: Path, +) -> None: + def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: + return trait.pipeline_tag == "qr_async_trload_v3" + + def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: + return not accept_only_v3(trait) + + content = "".join( + [ + FMHA_FWD_API_HEADER, + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), + api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + # NOTE: enable v3 pipelines when ready + # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + False + ] + ), + ] + ) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) def write_blobs( diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp deleted file mode 100644 index c510b36bb5..0000000000 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -auto parse_cmd_args(int argc, char* argv[]) -> std::pair -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q & k") - .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "0", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "0", "permute output") - .insert("causal", "0", "0: no mask, 1: causal mask") - .insert("v", "1", "0:no verify, 1:verify") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 0 for " - "non-deterministic seed") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "30", "number of iterations to benchmark the kernel") - // Optional effective seqlen override (exclude PAD) for batch mode - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); - - bool result = arg_parser.parse(argc, argv); - return std::make_pair(result, arg_parser); -} - -enum class TensorLayout -{ - bhsd, - bshd, -}; - -std::ostream& operator<<(std::ostream& stream, TensorLayout layout) -{ - switch(layout) - { - case TensorLayout::bhsd: return stream << "bhsd"; - case TensorLayout::bshd: return stream << "bshd"; - default: return stream << "unknown"; - } -} - -struct Problem -{ - explicit Problem(const ck_tile::ArgParser& args) - { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); - if(seqlen_k < 0) - { - seqlen_k = seqlen_q; - } - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); - if(nhead_kv < 0) - { - nhead_kv = nhead_q; - } - hdim = args.get_int("d"); - softmax_scale = args.get_float("scale_s"); - if(softmax_scale == .0f) - softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); - } - - std::vector get_query_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::vector get_key_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_value_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_output_shape() const - { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; - ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_kv; - ck_tile::index_t hdim; - float softmax_scale; - mask_info mask; - TensorLayout input_layout; - TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; -}; - -struct RunConfig -{ - explicit RunConfig(const ck_tile::ArgParser& args) - { - seed = args.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - kernel_warmup = args.get_int("warmup"); - kernel_repeat = args.get_int("repeat"); - verify = args.get_bool("v"); - } - - std::optional seed; - int kernel_warmup; - int kernel_repeat; - bool verify; -}; - -template -auto generate_qkv(const Problem& problem, - [[maybe_unused]] std::optional seed = std::nullopt) - -> std::tuple, - ck_tile::HostTensor, - ck_tile::HostTensor> -{ - ck_tile::HostTensor q(problem.get_query_shape()); - ck_tile::HostTensor k(problem.get_key_shape()); - ck_tile::HostTensor v(problem.get_value_shape()); - - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); - - return std::make_tuple(q, k, v); -} - -namespace host { -template -CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); - - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host - -template -bool run_impl(const Problem& problem, const RunConfig& run_config) -{ - auto [q, k, v] = generate_qkv(problem, run_config.seed); - - ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); - /// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v - ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q.data()); - k_buf.ToDevice(k.data()); - v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - - ck_tile::fmha_fwd_v3_args args{}; - - args.data_type = problem.data_type; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_kv = problem.nhead_kv; - args.hdim_qk = problem.hdim; - args.hdim_v = problem.hdim; - args.softmax_scale = problem.softmax_scale; - - args.window_size_left = problem.mask.left; - args.window_size_right = problem.mask.right; - args.mask_type = static_cast(problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - - ck_tile::stream_config stream_config{nullptr, - true, - /*log_level=*/0, - run_config.kernel_warmup, - run_config.kernel_repeat}; - - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); - if(!result) - { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; - return false; - } - - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); - float tflops = static_cast(flop) / 1.e9 / time; - - std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; - - if(!run_config.verify) - { - return true; - } - - // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } - - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } - - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); - - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); -} - -int main(int argc, char* argv[]) -{ - auto [parse_result, args] = parse_cmd_args(argc, argv); - if(!parse_result) - { - std::cerr << "failed to parse command line arguments" << std::endl; - } - - Problem problem(args); - RunConfig run_config(args); - - const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) - { - return run_impl(problem, run_config); - } - else - { - return run_impl(problem, run_config); - } - }; - - return !run(); -} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f279ebfcea..002d0a1035 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) +{ + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. + int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + template auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp deleted file mode 100644 index 1c0256cc0f..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" -#include "mask.hpp" - -namespace ck_tile { - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) -{ - switch(data_type) - { - case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; - case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; - default: return stream << "unknown"; - } -} - -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) -{ - if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - - return std::make_pair(false, -1.f); -} - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp deleted file mode 100644 index 54cc4960a5..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/host/stream_config.hpp" - -namespace ck_tile { - -struct fmha_fwd_v3_args -{ - enum class data_type_enum - { - fp16, - bf16 - }; - - data_type_enum data_type; - // bool is_varlen; - - index_t batch; - index_t seqlen_q; - index_t seqlen_k; - index_t nhead_q; - index_t nhead_kv; - index_t hdim_qk; - index_t hdim_v; - - float softmax_scale; - - index_t window_size_left; - index_t window_size_right; - index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and - // window_size_right == 0). - - const void* q_ptr; - index_t stride_q; - index_t nhead_stride_q; - index_t batch_stride_q; - - const void* k_ptr; - index_t stride_k; - index_t nhead_stride_k; - index_t batch_stride_k; - - const void* v_ptr; - index_t stride_v; - index_t nhead_stride_v; - index_t batch_stride_v; - - void* o_ptr; - index_t stride_o; - index_t nhead_stride_o; - index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] -}; - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp deleted file mode 100644 index 19b8dfed4e..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/container/sequence.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" - -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ - } - -namespace ck_tile { - -template -struct fmha_fwd_v3_problem_traits; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype = float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::bf16_t; - using acc_dtype = float; - using o_dtype = ck_tile::bf16_t; - using lse_dtype = float; -}; - -template -struct fmha_fwd_v3_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; - using fmha_warp_gemm_shape = sequence<32, 32, 16>; - using fmha_block_warps = sequence<8, 1, 1>; - - using fmha_shape = TileFmhaShape; - - using fmha_traits = TileFmhaFwdV3Traits; - - using fmha_mask = GenericAttentionMask; - - using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - fmha_shape, - IsVariableSeqlen, - fmha_mask, - fmha_traits>; - - using fmha_pipeline = BlockFmhaFwdV3Pipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using kernel = FmhaFwdV3Kernel; -}; - -template -float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) -{ - /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly - /// maximizes the kernel's performance. - int remap_opt = 2; - if(args.mask_type != static_cast(mask_enum::no_mask) && - ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) - { - if(65536 <= args.seqlen_q) - { - remap_opt = 0; - } - else - { - remap_opt = 1; - } - } - - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_qk, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_kv, - args.softmax_scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; - - return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template -std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, - const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp deleted file mode 100644 index 463c52b824..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp deleted file mode 100644 index acf79e43f4..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp deleted file mode 100644 index a6366209b2..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp deleted file mode 100644 index a83e37cc68..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 81eea60c2f..29a7e2593e 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing& printf("}"); } +template +struct functor_transform : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + Functor functor_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr functor_transform() = default; + + CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor, + const LowLength& low_length) + : functor_{functor}, up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = functor_(idx_up[number<0>{}]); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + calculate_lower_index(idx_low, up_idx); + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // Note: When using functor_transform, ensure that the transformed coordinates + // are always valid for vectorized load/store operations. + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } +}; + //******************************************************************************************************* template @@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le return offset{low_length, offset_length}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor, + const LowLength& low_length) +{ + return functor_transform{functor, low_length}; +} + } // namespace ck_tile #include "ck_tile/core/algorithm/indexing_adaptor.hpp" diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index e80267faec..d39da82a62 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths } }; -template +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1a79aebef5..756968871d 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask mdiv y_ratio_mdiv; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..9890d1f2e4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -73,54 +73,6 @@ struct FmhaFwdKernel #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + - (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)) + (kUseTrLoad ? "_trload" : "_ntrload"); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index df17bdd879..f981c54bd8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,6 +12,8 @@ namespace ck_tile { +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct FmhaFwdV3Kernel { @@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; // Optional cumulative padded sequence starts (including PAD tokens) // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasMask) @@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - // TODO: this may need tuning - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1)); } else { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + batch_size); } } @@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel // FmhaPipeline::kN1); // assume that num_tile_n1 is always 1 - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { const index_t i_nhead = blockIdx.x; - const index_t i_block = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.z; - return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } else { @@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel const index_t i_block = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } } @@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + batch_offset_o = query_start * kargs.stride_o; + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 8bf24be386..68ec349694 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct BlockFmhaFwdV3Pipeline { @@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; + static_assert(is_generic_attention_mask_v); static_assert(std::is_same_v, "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); using BlockFmhaShape = ck_tile::remove_cvref_t; + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; @@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + static constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; + static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + !kStoreLSE && !kHasDropout && + (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && + !kSkipMinSeqlenQ), + "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index da0fa16ee1..659bdd995b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_ASYNC_TRLOAD_V3, }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index b90b760a0d..7c4a921b70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdV3PipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index b9e18de1e5..df33a93696 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdV3Traits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - } // namespace ck_tile diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index affa6d987b..aeec7bd471 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -90,7 +90,7 @@ submodule = submodule_t() # formatting format_procs = [] for x in all_files: - dos2unix = f"python -m dos2unix {str(x)} {str(x)}" + dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}" clang_format = f"clang-format -style=file -i {str(x)}" # One process to avoid race conditions. cmd = f"{dos2unix} && {clang_format}" From 13f6d635653bd5ffbfcac8577f1ef09590c23d78 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Thu, 4 Dec 2025 19:12:36 -0800 Subject: [PATCH 03/24] Clean up conv_traits.hpp (#3354) When I asked for a description of operators that didn't have ConvTraits, I was getting very long confusing errors about ConvTraits not being defined. Now we get specific errors explaining which concepts are violated, making it easier to know which code to generalize or update. * Add concepts to conv_traits.hpp to get better error message. * Put the correct requires clauses in the right places to get descriptive error messages. * General cleanup of functions in conv_traits.hpp to make functions easier to read. --- .../builder/reflect/conv_description.hpp | 8 +- .../ck_tile/builder/reflect/conv_traits.hpp | 457 +++++++----------- 2 files changed, 186 insertions(+), 279 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 261c3f103d..59ff83c238 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -251,14 +251,10 @@ class ConvDescription : public Description }; } // namespace conv -/// @brief Helper concept to detect if a type has ConvTraits specialization -template -concept HasConvTraits = requires { typename conv::ConvTraits; }; - /// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have InstanceTraits specialization) +/// @tparam Instance The convolution instance type (must have ConvTraits specialization) /// @return A ConvDescription object populated with the instance's configuration details -template +template conv::ConvDescription describe() { using Traits = conv::ConvTraits; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 29ac49e549..918fd6bdb6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -21,6 +21,57 @@ namespace ck_tile::reflect::conv { +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// GEMM specialization concept - checks for kGemmSpecialization member +template +concept HasGemmSpec = requires { + { + T::kGemmSpecialization + } -> std::convertible_to; +}; + +// Data types concept - checks for ADataType member +template +concept HasDataTypes = requires { typename T::ADataType; }; + +// Elementwise operations concept - checks for A/B/CDE elementwise operation types +template +concept HasElementwiseOps = requires { + typename T::AElementwiseOperation; + typename T::BElementwiseOperation; + typename T::CDEElementwiseOperation; +}; + +// Tile parameters concept - checks for tile dimension and transfer members +template +concept HasTileParams = requires { + { T::kKPerBlock } -> std::convertible_to; + { T::kMPerBlock } -> std::convertible_to; + { T::kNPerBlock } -> std::convertible_to; + { T::kAK1 } -> std::convertible_to; + { T::kBK1 } -> std::convertible_to; + T::kCThreadClusterLengths; +}; + +// Comprehensive concept that checks if an instance has all XDL forward convolution traits +// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions +template +concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && + HasElementwiseOps && HasTileParams; + +// Primary concept for checking if a type can be described +// Currently only forward convolutions are supported, but this can be extended +// in the future to include backward data and backward weight convolutions +template +concept HasConvTraits = IsXdlFwdConv>; + // Helper metafunctions to convert from ck enums to builder enums /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. @@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version() { using enum ck::BlockGemmPipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v3) - return V3; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == v5) - return V5; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } } /// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. @@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version() { using enum ck::PipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == weight_only) - return WEIGHT_ONLY; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } } /// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. @@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::BlockGemmPipelineScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Intrawave) - return INTRAWAVE; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } } /// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. @@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::LoopScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Default) - return DEFAULT; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } } /// @brief Helper structures for organizing trait data with domain-specific naming @@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction() using InstTraits = InstanceTraits; if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - { return builder::ConvDirection::FORWARD; - } else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - { return builder::ConvDirection::BACKWARD_DATA; - } else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - { return builder::ConvDirection::BACKWARD_WEIGHT; - } else - { return builder::ConvDirection::FORWARD; // Default fallback - } } /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. @@ -242,60 +288,52 @@ constexpr auto conv_spec() if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvFwdSpecialization; - if constexpr(InstTraits::kConvForwardSpecialization == Default) + switch(InstTraits::kConvForwardSpecialization) { - return builder::ConvFwdSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) - { - return builder::ConvFwdSpecialization::FILTER_3x3; + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; } } else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvBwdDataSpecialization; - if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + switch(InstTraits::kConvBwdDataSpecialization) { - return builder::ConvBwdDataSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; } } else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvBwdWeightSpecialization; - if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + switch(InstTraits::kConvBwdWeightSpecialization) { - return builder::ConvBwdWeightSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) - { - return builder::ConvBwdWeightSpecialization::ODD_C; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; } } } +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return An std::array corresponding to the tensor layouts: @@ -304,112 +342,49 @@ constexpr auto conv_spec() /// index 2 -> Output layout template constexpr auto conv_layout() + requires HasFwdConvLayouts> { - using InstTraits = InstanceTraits; - using ALayout = typename InstTraits::ALayout; - using BLayout = typename InstTraits::BLayout; - using ELayout = typename InstTraits::ELayout; + // Helper lambda to construct layout array + auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - namespace ctc = ck::tensor_layout::convolution; + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(InstTraits::kSpatialDim == 1) + switch(InstanceTraits::kSpatialDim) { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNWC, - builder::TensorLayout::GKXC, - builder::TensorLayout::GNWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NWGC, - builder::TensorLayout::GKXC, - builder::TensorLayout::NWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKXC, - builder::TensorLayout::NGKW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKCX, - builder::TensorLayout::NGKW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 2) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNHWC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::GNHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NHWGC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NGKHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKCYX, - builder::TensorLayout::NGKHW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 3) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNDHWC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::GNDHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NDHWGC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NDHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NGKDHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKCZYX, - builder::TensorLayout::NGKDHW}; - } + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } } @@ -418,39 +393,26 @@ constexpr auto conv_layout() /// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). template constexpr builder::DataType conv_data_type() + requires HasDataTypes> { using InstTraits = InstanceTraits; using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; if constexpr(std::is_same_v) - { - return builder::DataType::FP16; - } + return FP16; else if constexpr(std::is_same_v) - { - return builder::DataType::BF16; - } + return BF16; else if constexpr(std::is_same_v) - { - return builder::DataType::FP32; - } + return FP32; else if constexpr(std::is_same_v) - { - return builder::DataType::FP8; - } + return FP8; else if constexpr(std::is_same_v) - { - return builder::DataType::I8; - } + return I8; else if constexpr(std::is_same_v) - { - return builder::DataType::U8; - } + return U8; else - { - // Default fallback - return builder::DataType::FP32; - } + return FP32; // Default fallback } /// @brief Derives the elementwise operation from op type. @@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type() template constexpr builder::ElementwiseOperation elementwise_op() { + using enum builder::ElementwiseOperation; constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - { - return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - { - return builder::ElementwiseOperation::CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - { - return builder::ElementwiseOperation::SCALE; - } - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - { - return builder::ElementwiseOperation::PASS_THROUGH; - } - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - { - return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU; - } + return BIAS_BNORM_CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; } /// @brief Derives a gemm padding from a kernel instance type. @@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op() /// @return A `builder::GemmPadding` enum value corresponding to kernel padding. template constexpr builder::GemmPadding gemm_spec() + requires HasGemmSpec> { using InstTraits = InstanceTraits; using enum builder::GemmPadding; @@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec() constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - if constexpr(gemm_spec == Default) + switch(gemm_spec) { - return DEFAULT; - } - else if constexpr(gemm_spec == MPadding) - { - return M_PADDING; - } - else if constexpr(gemm_spec == NPadding) - { - return N_PADDING; - } - else if constexpr(gemm_spec == KPadding) - { - return K_PADDING; - } - else if constexpr(gemm_spec == MNPadding) - { - return MN_PADDING; - } - else if constexpr(gemm_spec == MKPadding) - { - return MK_PADDING; - } - else if constexpr(gemm_spec == NKPadding) - { - return NK_PADDING; - } - else if constexpr(gemm_spec == MNKPadding) - { - return MNK_PADDING; - } - else if constexpr(gemm_spec == OPadding) - { - return O_PADDING; - } - else if constexpr(gemm_spec == MOPadding) - { - return MO_PADDING; - } - else if constexpr(gemm_spec == NOPadding) - { - return NO_PADDING; - } - else if constexpr(gemm_spec == KOPadding) - { - return KO_PADDING; - } - else if constexpr(gemm_spec == MNOPadding) - { - return MNO_PADDING; - } - else if constexpr(gemm_spec == MKOPadding) - { - return MKO_PADDING; - } - else if constexpr(gemm_spec == NKOPadding) - { - return NKO_PADDING; - } - else if constexpr(gemm_spec == MNKOPadding) - { - return MNKO_PADDING; + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; } } @@ -571,6 +481,7 @@ struct ConvTraits; /// set of traits directly from a fully-formed device kernel `Instance` type. /// It uses `InstanceTraits` to access the kernel's template parameters. template + requires IsXdlFwdConv> struct ConvTraits { using InstTraits = InstanceTraits; From f7650ee82b306a05d9c3c44d3feefdd570a4bd58 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 5 Dec 2025 09:30:22 +0100 Subject: [PATCH 04/24] fix enforcing fixedvectorsizes for ck tile conv (#3344) --- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d843916f5e..76341af70b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { using AsLayout = remove_cvref_t; using AsDataType = remove_cvref_t; @@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t{}, AsLayout>>; using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeA; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy using BLayout = remove_cvref_t{}, BsLayout>>; using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeB; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize Date: Fri, 5 Dec 2025 16:14:52 +0100 Subject: [PATCH 05/24] Add new section to changelog (#3295) * Add new section to changelog * Update CHANGELOG.md Co-authored-by: spolifroni-amd --------- Co-authored-by: spolifroni-amd --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b07e322fe1..a50303113d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). +## (Unreleased) Composable Kernel 1.3.0 + +### Added +* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. + +### Changed + +### Upcoming changes + ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added From f5b0af22722b130f03cac590ca9b8729b1b84991 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 5 Dec 2025 07:44:10 -0800 Subject: [PATCH 06/24] Simplify includes for CK builder reflection (#3357) We only want to import enums and types into the builder reflection code. But, some of the enums are included in much larger files or even big trees of include files. This leads to unintended mixing of code and very confusing interactions and symbol conflicts. We organize the includes and extract two new enum-only headers to help with decoupling in CK. This refactoring is critical if we want to include reflection in a device-operator "describe" method. * Remove a few unnecessary includes from headers in builder/reflect/. * Extract enums scheduler and pipeline to their own headers so they can be used without importing other code. * Order includes alphabetically for better organization. The immediate goal is to unblock reflection integration, and this type of cleanup helps the flexibility and robustness of the CK header library. --- .../ck_tile/builder/reflect/conv_traits.hpp | 26 +++--- .../builder/reflect/instance_traits_util.hpp | 42 +++++----- .../test/test_bwd_data_instance_traits.cpp | 7 +- .../test/test_bwd_weight_instance_traits.cpp | 10 ++- .../builder/test/test_fwd_instance_traits.cpp | 22 ++--- .../test/test_instance_traits_util.cpp | 18 ++-- .../grid/gridwise_gemm_pipeline_selector.hpp | 27 +----- include/ck/utility/blkgemmpipe_scheduler.hpp | 44 +--------- include/ck/utility/loop_scheduler.hpp | 28 +------ include/ck/utility/pipeline_enum.hpp | 40 +++++++++ include/ck/utility/scheduler_enum.hpp | 83 +++++++++++++++++++ 11 files changed, 197 insertions(+), 150 deletions(-) create mode 100644 include/ck/utility/pipeline_enum.hpp create mode 100644 include/ck/utility/scheduler_enum.hpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 918fd6bdb6..e5a5638887 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -4,20 +4,20 @@ #pragma once #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/conv_builder.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" #include "ck_tile/ops/epilogue.hpp" -#include +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 64996f96f7..1055cbc038 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -8,28 +8,30 @@ #pragma once #include -#include -#include -#include -#include -#include -#include #include -#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ck_tile/ops/epilogue.hpp" +#include +#include +#include +#include +#include +#include +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index 80e8ae8d98..f26b5d7caf 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -2,9 +2,10 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index 9b3cd169bb..c7c4e370e2 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6a8f1f14e3..396533cef4 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -1,17 +1,19 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp index 42810ace72..852174b805 100644 --- a/experimental/builder/test/test_instance_traits_util.cpp +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -1,16 +1,16 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" namespace ck_tile::reflect::detail { namespace { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 8d45b8fd74..751608299c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -5,24 +5,16 @@ #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include -#include #endif +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" namespace ck { -enum struct PipelineVersion -{ - v1, - v2, - // v3 is only used in the Stream-K implementation. - v4, - weight_only, -}; - template Prefetch stages, number of loop is multiple of unroll stages - Empty, - // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add - // prefetchstages - Full, -}; - enum SchedulerGroup : uint32_t { SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index f186d0fea9..b3303e1138 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -3,40 +3,20 @@ #pragma once -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include -#endif - #include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - +/// @brief Helper function to get default loop scheduler +/// @details Returns the default loop scheduler based on compile-time configuration. constexpr LoopScheduler make_default_loop_scheduler() { #if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING return LoopScheduler::Interwave; #else return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +#endif } } // namespace ck - -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) -{ - switch(s) - { - case ck::LoopScheduler::Default: os << "Default"; break; - case ck::LoopScheduler::Interwave: os << "Interwave"; break; - default: os << ""; - } - return os; -} -#endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp new file mode 100644 index 0000000000..4421386f59 --- /dev/null +++ b/include/ck/utility/pipeline_enum.hpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Pipeline version enumeration for GEMM kernels +/// @details Defines different pipeline strategies for data movement and computation overlap +/// in GEMM kernels. This is a lightweight header containing only the enum definition, +/// extracted from gridwise_gemm_pipeline_selector.hpp to minimize dependencies. +enum struct PipelineVersion +{ + v1, ///< Version 1 pipeline + v2, ///< Version 2 pipeline + // v3 is only used in the Stream-K implementation. + v4, ///< Version 4 pipeline + weight_only, ///< Weight-only specialized pipeline +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} +#endif diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp new file mode 100644 index 0000000000..0c4bfabaf3 --- /dev/null +++ b/include/ck/utility/scheduler_enum.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Block GEMM pipeline version enumeration +/// @details Defines different block GEMM pipeline strategies. +/// This is a lightweight header containing only enum definitions, +/// extracted from blkgemmpipe_scheduler.hpp to minimize dependencies. +enum struct BlockGemmPipelineVersion +{ + // For GEMM + v1, ///< Naive pipeline + v2, ///< Memory-optimized pipeline + v3, ///< Compute-optimized pipeline + v4, ///< Compute-optimized with double LDS buffer + v5, ///< Compute-optimized with double global prefetch register buffer + + // For GEMM with preshuffled weight + // v1, single lds buffer + // v2, double lds buffer +}; + +/// @brief Block GEMM pipeline scheduler enumeration +/// @details Defines scheduling strategies for block GEMM pipelines. +enum struct BlockGemmPipelineScheduler +{ + Intrawave, ///< Schedule within a single wavefront + Interwave, ///< Schedule across multiple wavefronts +}; + +/// @brief Loop scheduler enumeration +/// @details Defines scheduling strategies for computational loops. +enum struct LoopScheduler +{ + Default, ///< Default scheduling strategy + Interwave, ///< Cross-wavefront scheduling +}; + +/// @brief Tail number enumeration for pipeline buffering +/// @details Defines the number of tail iterations in pipelined loops. +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, ///< Odd number of iterations + Even, ///< Even number of iterations + + // Long prefetch pipeline, up to 8 + One, ///< One tail iteration + Two, ///< Two tail iterations + Three, ///< Three tail iterations + Four, ///< Four tail iterations + Five, ///< Five tail iterations + Six, ///< Six tail iterations + Seven, ///< Seven tail iterations + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, ///< No tail iterations + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, ///< Full tail iterations +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +{ + switch(s) + { + case ck::LoopScheduler::Default: os << "Default"; break; + case ck::LoopScheduler::Interwave: os << "Interwave"; break; + default: os << ""; + } + return os; +} +#endif From 82f796a1f096219da34d614b6084a03bb23f8dc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 5 Dec 2025 17:20:46 +0100 Subject: [PATCH 07/24] Profile resnet layout fixes (#3360) --- .../include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp | 4 ++-- profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp index 3cda620831..47a12e2d88 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp index 2a7ee6fd66..ac7ab78ed7 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; From 7541d9b5b0e0ce241eb75476e3ef5d61ba019210 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Fri, 5 Dec 2025 08:26:00 -0800 Subject: [PATCH 08/24] Ignore .cmake-format.yaml (#3356) We don't want to add cmake formatting until we are in the super repo, but its handy if developers want to experiment with formatting. For now we should ignore .cmake-format.yaml. --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 2641a661d8..d8468cf24e 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,9 @@ tags # Editors .vscode +# CMake formatting configuration (local) +.cmake-format.yaml + # Cline .cline* From ed080f5a56c38caea8fedbd0bcc2919ba2376a6f Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:35:27 -0700 Subject: [PATCH 09/24] Congma/ck tile/aquant mem pipeline (#3346) * [CK TILE GEMM QUANT] Fix the bug in HotLoopTail of memory pipeline --- .../run_gemm_quant_example.inc | 11 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 8 +- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 151 ++++++++++++++---- 3 files changed, 127 insertions(+), 43 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 396a54c7c2..0ee19b4a26 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 71e0ebb957..38a22e38ac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -36,17 +36,13 @@ struct BaseGemmPipelineAgBgCrMem // TODO: Is this 32K value gfx9 arch specific? static constexpr index_t MinMemInFlyBytes = 32768; - static constexpr index_t WgpPerCU = - (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; + static constexpr index_t WgpPerCU = ck_tile::max(4 * get_warp_size() / BlockSize, 1); static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, (MPerBlock * sizeof(ADataType) / APackedSize + NPerBlock * sizeof(BDataType) / BPackedSize) * KPerBlock); - static constexpr index_t PrefetchStages = - FullMemBandPrefetchStages >= 2 - ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 - : 2; + static constexpr index_t PrefetchStages = ck_tile::clamp(FullMemBandPrefetchStages, 2, 8); static constexpr index_t LocalPrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index f3c8b7a1a3..7f89d98349 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -80,6 +80,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +168,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = PipelineImplBase; + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); + } + template const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - index_t m, + [[maybe_unused]] index_t m, index_t num_loop, void* p_smem) const { - (void)m; // unused variable static_assert( std::is_same_v> && std::is_same_v std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); - static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], - "Aq block window has incorrect lengths for defined AqLayout!"); static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -217,7 +228,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); @@ -249,7 +260,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -272,7 +283,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); // Global prefetch initialization - DRAM to VGPRs - Base::GlobalPrefetch( + LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); @@ -282,10 +293,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS prefill - VGPRs to LDS - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -293,10 +304,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -306,9 +317,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Additional prefetching for memory pipeline - DRAM to VGPRs static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -325,16 +336,17 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, aq_block_tiles.get(number{}), a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); // Prepare next iteration data - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -348,7 +360,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -365,9 +377,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_element_func); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -381,20 +393,89 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Tail handling - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - block_gemm( - c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + auto HotLoopTail = [&](auto tail_num) { + static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + // no second block_sync_lds because it's interwave - if constexpr(TailNum == TailNumber::Even) - { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, + a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{})); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, + b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{})); + } + }); - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm( - c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window); + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); } return c_block_tile; } @@ -413,7 +494,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, From 608232ce82636e7c9ab8dec55dc7507c6792fb65 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 5 Dec 2025 08:39:18 -0800 Subject: [PATCH 10/24] do not build hipblaslt for gfx90a to save time and disc space (#3362) --- Dockerfile.pytorch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch index 9628bf46fa..2d3856fa2d 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx90a;gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller From 6b1bceca7baea62941793e562d6ff58c571d9191 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Fri, 5 Dec 2025 09:57:52 -0800 Subject: [PATCH 11/24] [CK_Tile] Enable PreshuffleB for 2d block scale Gemm (#3298) * formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * debugging permuteN * debugging * debugging PermuteN * initial commit * resolving merge conflicts * adding test cases * fixing bq tensor calculation --------- Co-authored-by: Cong Ma Co-authored-by: Thomas Ning --- .../gemm_bquant_quantgrouped_preshuffleb.cpp | 192 ++++++++++++++++-- .../run_gemm_quant_example.inc | 27 ++- include/ck_tile/host/tensor_shuffle_utils.hpp | 10 +- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 13 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 1 - .../test_gemm_quant_bquant_preshuffle.cpp | 44 +++- .../test_gemm_quant_fixtures.hpp | 6 +- 7 files changed, 257 insertions(+), 36 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp index 8ebf5bbd96..b32356c29d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp @@ -14,36 +14,154 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; void bquant_quantgrouped_preshuffleb_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "preshuffleb", + "non-preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, @@ -52,10 +170,50 @@ void bquant_quantgrouped_preshuffleb_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 0ee19b4a26..8a0dd9bc08 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -140,6 +140,13 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::WPQuantBPipelineAgBgCrV2, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; + constexpr bool TiledPermuteN = + (QuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; + if(s.log_level_ > 0) + { + printf( + "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, QuantGroupSize::kN); + } using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -382,7 +389,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, "K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode"); } } - ck_tile::index_t AQK, BQK; + ck_tile::index_t AQK, BQK, BQN = 0; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize @@ -392,6 +399,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + BQN = ck_tile::integer_divide_ceil(N, QuantGroupSize::kN); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) @@ -431,7 +439,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQ = 0; // No A quantization - stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout)); + stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout)); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { @@ -471,7 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { @@ -557,7 +565,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, b_k_n.SetZero(); bq_tensor_ptr->SetZero(); } - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -610,7 +617,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PreshuffleB) { - if constexpr(GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -635,11 +642,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant) { - if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN) + if constexpr(GemmConfig::PreshuffleB && GemmConfig::TiledMMAPermuteN && + QuantGroupSize::kN == 1) { - printf("Preshuffle BQ with TiledMMAPermuteN \n"); ck_tile::HostTensor bq_permuted_host = - ck_tile::bq_permuteN(*bq_tensor_ptr); + ck_tile::bq_permuteN(*bq_tensor_ptr, QuantGroupSize::kN); if constexpr(GemmConfig::PreshuffleQuant) { @@ -659,7 +666,9 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else + { bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); + } } invoke_gemm* t, int block_aq_k) } int m_ = t->get_lengths()[0]; int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) { throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); @@ -110,7 +111,7 @@ auto shuffle_b(const ck_tile::HostTensor& t) } template -auto bq_permuteN(const ck_tile::HostTensor& t) +auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { assert(t.get_lengths().size() == 2); @@ -118,8 +119,11 @@ auto bq_permuteN(const ck_tile::HostTensor& t) int bqk_ = t.get_lengths()[0]; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - ck_tile::HostTensor t_view( - {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_}); + ck_tile::HostTensor t_view({n_ / (GemmConfig::N_Tile / group_n), + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile / group_n, + NRepeat, + bqk_}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 2, 4}); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index b54a93614a..58b713cb35 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -28,7 +28,6 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg using QuantGroupSize = remove_cvref_t; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); - static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -205,7 +204,17 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg } else { - constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale; + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = cvt_scale_to_fp32(scale_reg); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index f6cf4ce9be..dd85705cf2 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -747,7 +747,6 @@ struct QuantGemmKernel (splitk_batch_offset.splitted_k / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( b_ptr, make_tuple(kFlatN, kFlatK), diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 59b267842f..6cde4bded5 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -19,6 +19,12 @@ using PkInt4 = ck_tile::pk_int4_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + // Type combinations for BQuant tests with PreshuffleB // Tuple format: @@ -37,7 +43,43 @@ using BPreshuffleBQuantTypes = ::testing::Types< std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // //2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 3b62d8073e..7b16529aa8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -433,7 +433,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; if constexpr(PreshuffleB) { - if constexpr(TiledMMAPermuteN) + if constexpr(TiledMMAPermuteN && QuantGroupSize::kN == 1) { printf("PreshuffleB with TiledMMAPermuteN\n"); b_k_n_dev = ck_tile::shuffle_b_permuteN(b_k_n); @@ -451,11 +451,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase bq_shuffle_host = - ck_tile::bq_permuteN(bq_bqk_bqn); + ck_tile::bq_permuteN(bq_bqk_bqn, QuantGroupSize::kN); bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data()); } else if constexpr(GemmConfig::PreshuffleQuant) From 86a84ae61122b8ed2d2e40e45f108a8fa23d3210 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 5 Dec 2025 14:18:30 -0800 Subject: [PATCH 12/24] Add the gfx1011 support on CK Tile with the SGPR builtin reading protection (#3350) * Finish the fixes * add the gfx1010 support macro * Fix the compilation error --- include/ck_tile/core/config.hpp | 7 ++++ .../core/tensor/tile_scatter_gather.hpp | 3 +- .../core/tensor/tile_window_linear.hpp | 3 +- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 37 +++++++++++++++---- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index de97b46336..678a2fbfff 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state #endif // __gfx950__ // GFX10 +#if defined(__gfx1010__) + static constexpr bool CK_TILE_ARCH_GFX1010 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1010 = false; +#endif + #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; #else @@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 97a44f38e8..7a4da64c4a 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -533,7 +533,8 @@ struct tile_scatter_gather size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using Traits = load_store_traits; diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 815c1bf158..6c84122d01 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -517,7 +517,8 @@ struct tile_window_linear size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using vector_t = typename Base::Traits::vector_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d83338fbb2..51f0f5f1b1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) constexpr index_t Bload_inst = (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), QuantGroupSize::kK * QuantGroupSize::kK), VectorLoadSize); - constexpr index_t kLdsVec = 8; + + // ToDo: Hardcoded, need to change in future. How many instruction emit per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; - constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; - constexpr index_t ds_write_inst = Aload_inst; - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); - constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma - static_for<0, nloop, 1>{}([&](auto j_inst) { - ignore = j_inst; + static_for<0, nloop, 1>{}([&](auto) { static_for<0, mfma_inst, 1>{}([&](auto i_inst) { __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model. if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) { __builtin_amdgcn_sched_group_barrier( @@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read } } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU }); }); From 8fec8054b2473bc7e367adf009f40a9d3fcc52df Mon Sep 17 00:00:00 2001 From: yinglu Date: Mon, 8 Dec 2025 16:24:20 +0800 Subject: [PATCH 13/24] ck: add tf32 in `DTYPES` to control instances build(#3317) --- CHANGELOG.md | 1 + CMakeLists.txt | 17 +++++++++++++++ README.md | 2 +- client_example/CMakeLists.txt | 12 +++++++++++ include/ck/config.h.in | 11 ++++++++++ .../gpu/grouped_convolution_backward_data.hpp | 16 ++++++++------ ...ped_convolution_backward_data_bilinear.hpp | 20 +++++++++++------- ...rouped_convolution_backward_data_scale.hpp | 20 +++++++++++------- .../grouped_convolution_backward_weight.hpp | 16 ++++++++------ ...d_convolution_backward_weight_bilinear.hpp | 11 +++++++--- ...uped_convolution_backward_weight_scale.hpp | 10 ++++++--- .../gpu/grouped_convolution_forward.hpp | 17 +++++++++------ ...d_convolution_forward_bias_bnorm_clamp.hpp | 16 ++++++++------ ...grouped_convolution_forward_bias_clamp.hpp | 18 +++++++++------- .../grouped_convolution_forward_bilinear.hpp | 10 ++++++--- .../gpu/grouped_convolution_forward_clamp.hpp | 17 ++++++++------- .../gpu/grouped_convolution_forward_scale.hpp | 10 ++++++--- .../gpu/CMakeLists.txt | 21 +++++++++++++------ .../src/profile_grouped_conv_bwd_data.cpp | 18 ---------------- .../src/profile_grouped_conv_bwd_weight.cpp | 16 -------------- profiler/src/profile_grouped_conv_fwd.cpp | 20 ------------------ .../profile_grouped_conv_fwd_bias_clamp.cpp | 6 ------ .../src/profile_grouped_conv_fwd_clamp.cpp | 6 ------ test/CMakeLists.txt | 6 ++++++ 24 files changed, 177 insertions(+), 140 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a50303113d..15fdb09f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. +* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". ### Changed diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d0c4d79f9..acae1f5ece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + # definition will be added based on the GPU target in the following section + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -106,6 +110,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") set(CK_ENABLE_FP8 "ON") @@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") set(CK_GFX950_SUPPORT "ON") endif() +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") +else() + message(STATUS "Disabling TF32 instances") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") +endif() + option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) @@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() diff --git a/README.md b/README.md index 01d523c2ab..8a5258bab6 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb Additional cmake flags can be used to significantly speed-up the build: -* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 2ed338d08a..cab84f5c6c 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -27,6 +27,9 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -41,6 +44,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") if (GPU_TARGETS MATCHES "gfx94") @@ -67,6 +71,14 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() + if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") + else() + message(STATUS "Disabling TF32 instances for this target") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") + endif() else() add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 306a6c2ff1..113bf99243 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,6 +55,11 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#define CK_ENABLE_TF32 "ON" +#endif +#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -85,6 +90,12 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ +#endif +#endif + #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 03e3ae88a3..89009c6d0b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( @@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( op_ptrs); @@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); @@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp index cd65a2285a..84a715b70a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in PassThrough, PassThrough, Bilinear>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { static_assert(is_same_v, "ComputeTypeA and ComputeTypeB must be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp index 36980e5935..c898dbf781 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta PassThrough, PassThrough, Scale>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { static_assert(is_same_v, " only support same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index e677f6f848..3fe8fa9c5a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && @@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index 448a6b5d51..a0e8e46570 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_ PassThrough, Bilinear, PassThrough>>>& instances); +#endif + +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index acf9c9e150..64bbdf6ec5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, Scale, PassThrough>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ba2f6b921a..5089ea2c1e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same!"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); @@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( @@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && @@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( @@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp index 46bc0d2320..d4729f4d13 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp @@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory && @@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory && @@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 1 && is_same_v, NDHWGK>) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 90852d2945..090c99819f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory && @@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 0) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index eeaf269394..ef037526ca 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME) set(type1 "_f16") elseif(type MATCHES "fp32") set(type1 "_f32") + elseif(type MATCHES "tf32") + set(type1 "_tf32") elseif(type MATCHES "fp8") set(type1 "_f8") elseif(type MATCHES "bf16") @@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME) #if filename matches any selected type, exit type loop and do no exclude the file from the list set(test 0) break() - elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR - source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND + elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR + source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND NOT (source_name MATCHES type OR source_name MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Only build tf32 instances for gfx942 & gfx950 - if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_") - message(DEBUG "removing tf32 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") + if(source_name MATCHES "_tf32_") + if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) + message(DEBUG "removing tf32 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endif() endforeach() @@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "fp32 instance found!") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + message(DEBUG "tf32 instance found!") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") message(DEBUG "fp64 instance found!") set(add_inst 1) @@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list}) "${cmake_instance}" MATCHES "_f16" OR "${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32" OR + "${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64" OR "${cmake_instance}" MATCHES "_bf16" OR @@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list}) list(APPEND CK_DEVICE_OTHER_INSTANCES $) endif() message(DEBUG "add_instance_directory ${subdir_path}") - endif() + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 62d6e860f9..cbf763fc13 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } @@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index a18aab41a5..c4f154e180 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index c94b77dd4f..4319d849c8 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) || defined(__gfx950__) using TF32 = ck::tf32_t; -#endif // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // NHWGC_GKYXC_NHWGK @@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // NGCDHW_GKCZYX_NGKDHW @@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp index 4eb12e6e19..79b9beb8c7 100644 --- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index 7df9fd6167..f497ee8da5 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f8498c6c03..c221f11f46 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -65,6 +65,9 @@ function(add_test_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() @@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() From 04612c30ceab818cd6c03a3e833a6c6d1a21dafa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 8 Dec 2025 10:32:56 +0100 Subject: [PATCH 14/24] [CK_BUILDER] Ck Tile Grouped convolution factory (#3352) * [BUILDER] Ck Tile Grouped convolution factory * Part 2 * Fixes after rebase * Remove leftovers --- .../builder/conv_algorithm_concepts.hpp | 85 +++++++- .../ck_tile/builder/conv_algorithm_limits.hpp | 5 + .../builder/factory/conv_dispatcher.hpp | 29 ++- .../builder/factory/conv_fwd_dl_factory.hpp | 10 +- .../factory/conv_fwd_large_tensor_factory.hpp | 12 +- .../builder/factory/conv_fwd_v3_factory.hpp | 12 +- .../builder/factory/conv_fwd_wmma_factory.hpp | 12 +- .../builder/factory/conv_fwd_xdl_factory.hpp | 12 +- .../builder/factory/conv_tile_factory.hpp | 131 ++++++++++++ .../helpers/{ => ck}/conv_block_transfer.hpp | 0 .../helpers/{ => ck}/conv_elementwise_op.hpp | 0 .../helpers/{ => ck}/conv_tensor_layout.hpp | 0 .../helpers/{ => ck}/conv_tensor_type.hpp | 0 .../helpers/{ => ck}/conv_thread_block.hpp | 0 .../helpers/{ => ck}/conv_tuning_params.hpp | 0 .../ck_tile/conv_tile_block_transfer.hpp | 25 +++ .../ck_tile/conv_tile_elementwise_op.hpp | 62 ++++++ .../ck_tile/conv_tile_kernel_directions.hpp | 88 ++++++++ .../ck_tile/conv_tile_tensor_layout.hpp | 200 ++++++++++++++++++ .../helpers/ck_tile/conv_tile_tensor_type.hpp | 87 ++++++++ .../ck_tile/conv_tile_thread_block.hpp | 32 +++ .../ck_tile/conv_tile_tuning_params.hpp | 158 ++++++++++++++ .../builder/include/ck_tile/builder/types.hpp | 9 + experimental/builder/test/CMakeLists.txt | 31 +-- .../{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp | 0 .../conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp | 0 ...est_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp | 0 ...test_ckb_conv_fwd_2d_large_tensor_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp | 0 .../{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp | 0 .../test/conv/{ => ck}/test_conv_traits.cpp | 0 .../test_ckb_conv_bwd_data_2d_fp16_v3.cpp | 52 +++++ .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 52 +++++ .../ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp | 52 +++++ .../test/impl/conv_algorithm_types.hpp | 118 +++++++++++ .../builder/test/unit_conv_elementwise_op.cpp | 2 +- .../builder/test/unit_conv_tensor_layout.cpp | 2 +- .../builder/test/unit_conv_tensor_type.cpp | 2 +- .../builder/test/unit_conv_thread_block.cpp | 2 +- .../builder/test/unit_conv_tuning_params.cpp | 2 +- .../test/utils/ckb_conv_test_utils.hpp | 16 ++ .../test/utils/ckb_conv_tile_test_configs.hpp | 85 ++++++++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +- .../gemm/pipeline/gemm_pipeline_problem.hpp | 7 +- ...ouped_convolution_backward_data_kernel.hpp | 17 +- ...ped_convolution_backward_weight_kernel.hpp | 37 ++-- .../grouped_convolution_forward_kernel.hpp | 36 ++-- .../utils/grouped_convolution_utils.hpp | 37 ++++ 55 files changed, 1431 insertions(+), 92 deletions(-) create mode 100644 experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_block_transfer.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_elementwise_op.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_layout.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tensor_type.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_thread_block.hpp (100%) rename experimental/builder/include/ck_tile/builder/factory/helpers/{ => ck}/conv_tuning_params.hpp (100%) create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp create mode 100644 experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_1d_i8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_dl_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_fp8.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_bf16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp16.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_ckb_conv_fwd_3d_fp32.cpp (100%) rename experimental/builder/test/conv/{ => ck}/test_conv_traits.cpp (100%) create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp create mode 100644 experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ecb1ff933e..bf7e89fcaa 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileThreadBlockDescriptor = requires(T t) { + { t.tile_size.m } -> std::convertible_to; + { t.tile_size.n } -> std::convertible_to; + { t.tile_size.k } -> std::convertible_to; +}; + +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileTransferDescriptor = requires(T t) { + { t.a_scalar_per_vector } -> std::convertible_to; + { t.b_scalar_per_vector } -> std::convertible_to; + { t.c_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept TileBlockGemmDescriptor = requires(T t) { + { t.warps.m } -> std::convertible_to; + { t.warps.n } -> std::convertible_to; + { t.warps.k } -> std::convertible_to; + { t.warp_tile.m } -> std::convertible_to; + { t.warp_tile.n } -> std::convertible_to; + { t.warp_tile.k } -> std::convertible_to; + { t.double_smem_buffer } -> std::convertible_to; + { t.num_wave_groups } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept TileOptimizationsDescriptor = requires(T t) { + { t.num_groups_to_merge } -> std::convertible_to; + { t.split_image } -> std::convertible_to; + { t.explicit_gemm } -> std::convertible_to; +}; + // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this // concept. template @@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; +// Concept to check if struct specifies thread block info (CK Tile). +template +concept SpecifiesTileThreadBlock = requires { + { T::thread_block } -> TileThreadBlockDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseXdlGemm = requires { @@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. +template +concept SpecifiesTileTransfer = requires(T t) { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + { T::transfer.c_scalar_per_vector } -> std::convertible_to; +}; + // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. template concept SpecifiesLdsTransfer = requires(T t) { @@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; +// Concept to check if struct specifies block GEMM (CK Tile). template -concept SpecifiesFwdConcSpecialization = requires { +concept SpecifiesTileBlockGemm = requires { + { T::block_gemm.warps.m } -> std::convertible_to; + { T::block_gemm.warps.n } -> std::convertible_to; + { T::block_gemm.warps.k } -> std::convertible_to; + { T::block_gemm.warp_tile.m } -> std::convertible_to; + { T::block_gemm.warp_tile.n } -> std::convertible_to; + { T::block_gemm.warp_tile.k } -> std::convertible_to; + { T::block_gemm.double_smem_buffer } -> std::convertible_to; + { T::block_gemm.num_wave_groups } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept SpecifiesTileOptimizations = requires { + { T::optimizations.num_groups_to_merge } -> std::convertible_to; + { T::optimizations.split_image } -> std::convertible_to; + { T::optimizations.explicit_gemm } -> std::convertible_to; +}; + +template +concept SpecifiesTileConvSpecialization = requires { + { T::specialization } -> std::convertible_to; +}; + +template +concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 093916dac3..10a619024a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires { Value.lds_dst_scalar_per_vector > 0; }; +// Limits for input and output vector transfer (CK Tile). +template +concept TileInputOutputVectorTransferLimits = + requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; + // Limits for output vector transfer. template concept OutputVectorTransferLimits = requires { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 51945544b2..9a9c2235e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -59,6 +59,7 @@ #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" +#include "ck_tile/builder/factory/conv_tile_factory.hpp" namespace ck_tile::builder::factory { @@ -81,6 +82,15 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. +// CK Tile kernel +template +consteval bool IsTileAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && + SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && + SpecifiesTileOptimizations; +} + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template consteval bool IsXdlV3Algorithm() @@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; } @@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; } @@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; } @@ -120,7 +130,7 @@ template consteval bool IsDlAlgorithm() { return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; } @@ -137,10 +147,15 @@ template constexpr auto make_conv_instance() { - if constexpr(ConvDirectionIsForward) - { - using AlgoType = std::remove_const_t; + using AlgoType = std::remove_const_t; + // CK Tile supports common factory for each direction + if constexpr(IsTileAlgorithm()) + { + return typename ConvTileFactory::Instance{}; + } + else if constexpr(ConvDirectionIsForward) + { if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 0c675ac7f1..ca202aabfd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,11 +7,11 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 98e368ca61..fadf41f48a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 79955a1f44..89787cc1b3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index fcce46aea7..bb84479071 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index df7fb25168..8ec5c633ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp new file mode 100644 index 0000000000..cce95cb3f1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp" + +namespace ck_tile::builder::factory { + +// Factory for CK Tile Grouped Convolution kernels. +template +struct ConvTileFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::TileConvTensorLayouts; + using Types = internal::TileConvTensorTypes; + using Ops = internal::TileElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization(); + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations(); + static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer(); + static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(TileInputOutputVectorTransferLimits); + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + BLOCK_GEMM.double_smem_buffer, + typename GroupedConvTraitsType::template GemmLayouts::AsLayout, + typename GroupedConvTraitsType::template GemmLayouts::BsLayout, + typename GroupedConvTraitsType::template GemmLayouts::CLayout, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + BLOCK_GEMM.num_wave_groups>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + GemmShape, + GemmUniversalTraits, + BLOCK_GEMM.scheduler, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Types::EDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename internal::TilePipelineType< + BLOCK_GEMM.pipeline_version>::template GemmPipeline; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Instance = typename internal::GroupedConvolutionTileKernel::Instance; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp new file mode 100644 index 0000000000..fbeb48b045 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +struct TileScalarPerVector +{ + size_t a = 0; + size_t b = 0; + size_t c = 0; +}; + +template +constexpr TileScalarPerVector SetTileBlockTransfer() +{ + return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector, + .b = ALGORITHM.transfer.b_scalar_per_vector, + .c = ALGORITHM.transfer.c_scalar_per_vector}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp new file mode 100644 index 0000000000..45ff7d265d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOpToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Clamp; +}; + +template +consteval auto GetTileElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCKTile{}; + } + else + { + return ElementwiseOpToCKTile{}; + } +} + +template +struct TileElementwiseOps +{ + static constexpr auto input_op = GetTileElementwiseOp(); + static constexpr auto weight_op = GetTileElementwiseOp(); + static constexpr auto output_op = GetTileElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp new file mode 100644 index 0000000000..189b199ffc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct GroupedConvolutionTileKernel +{ + static_assert(false, "Unknown Direction"); +}; + +template + requires ConvDirectionIsForward +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionForwardKernel; +}; + +template + requires ConvDirectionIsBackwardData +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardDataKernel; +}; + +template + requires ConvDirectionIsBackwardWeight +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel; +}; + +template +consteval ck_tile::GroupedConvDirection SetTileConvDirection() +{ + constexpr auto direction = SIGNATURE.direction; + using ck_tile_direction = ck_tile::GroupedConvDirection; + switch(direction) + { + case ConvDirection::FORWARD: return ck_tile_direction::FORWARD; + case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA; + case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT; + default: throw "Unknown Direction"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp new file mode 100644 index 0000000000..2aaca98586 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -0,0 +1,200 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { +using ALayout = ck_tile::tensor_layout::convolution::NWGC; +template +struct LayoutToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); +}; + +// Bias layouts +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_K; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_C; +}; + +// Input 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWC; +}; + +// Input 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWC; +}; + +// Input 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWC; +}; + +// Weight 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCX; +}; + +// Weight 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCZYX; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCKTile() +{ + return typename LayoutToCKTile::type{}; +} + +struct EmptyAuxiliaryTileTensorLayout +{ + using type = ck_tile::tuple<>; +}; + +template +consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) +{ + return ck_tile::tuple< + decltype(TensorLayoutToCKTile())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTileTensorLayouts +{ + static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). +template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return AuxiliaryTileTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return EmptyAuxiliaryTileTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct TileConvTensorLayouts +{ + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp new file mode 100644 index 0000000000..493fbb7d9b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK Tile tensor types. +template +struct TileConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. Unsupported data type for convolution factory."); +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::half_t; + using AComputeType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using BComputeType = ck_tile::half_t; + using CShuffleDataType = ck_tile::half_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::half_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::bf16_t; + using AComputeType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using BComputeType = ck_tile::bf16_t; + using CShuffleDataType = ck_tile::bf16_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::bf16_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::fp8_t; + using AComputeType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using BComputeType = ck_tile::fp8_t; + using CShuffleDataType = ck_tile::fp8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::fp8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp new file mode 100644 index 0000000000..65d81a49c4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileConvBlock +{ + TileBlockMNK per_block = {}; +}; + +template +constexpr TileConvBlock SetTileThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return TileConvBlock{ + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp new file mode 100644 index 0000000000..b7df0e4d0e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockGemmMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileBlockGemmSpec +{ + TileBlockGemmMNK warps = {}; + TileBlockGemmMNK warp_tile = {}; + + bool double_smem_buffer = false; + int num_wave_groups = 1; + + ck_tile::GemmPipeline pipeline_version; + ck_tile::GemmPipelineScheduler scheduler; +}; + +struct TileOptimizations +{ + int num_groups_to_merge = 1; + bool split_image = false; + bool explicit_gemm = false; +}; + +template +consteval ck_tile::GemmPipelineScheduler SetTileScheduler() +{ + constexpr auto scheduler = ALGORITHM.block_gemm.scheduler; + using ck_tile_sched = ck_tile::GemmPipelineScheduler; + switch(scheduler) + { + case PipelineScheduler::DEFAULT: return ck_tile_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave; + case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave; + default: throw "Unknown PipelineScheduler"; + } +} + +template +struct TilePipelineType +{ + static_assert(false, "Unknown PipelineScheduler"); +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; +}; + +template +consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.block_gemm.pipeline_version; + using ck_tile_pipeline = ck_tile::GemmPipeline; + switch(version) + { + case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1; + case PipelineVersion::V2: return ck_tile_pipeline::MEMORY; + case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3; + case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; + case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.specialization; + using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization; + switch(specialization) + { + case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default; + case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0; + case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_tile_conv_spec::Filter1x1Stride1Pad0; + case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; + } +} + +template +consteval TileBlockGemmSpec SetTileBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + constexpr bool double_smem_buffer = BG.double_smem_buffer; + constexpr int num_wave_groups = BG.num_wave_groups; + + constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion(); + constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler(); + + return TileBlockGemmSpec{ + .warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k}, + .warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k}, + .double_smem_buffer = double_smem_buffer, + .num_wave_groups = num_wave_groups, + .pipeline_version = pipeline_version, + .scheduler = scheduler}; +} + +template +consteval TileOptimizations SetTileOptimizations() +{ + constexpr auto& OPT = ALGORITHM.optimizations; + + return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, + .split_image = OPT.split_image, + .explicit_gemm = OPT.explicit_gemm}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 565bb98528..532d8a1882 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -145,6 +145,15 @@ enum struct GemmSpecialization MNKOPadding }; +// Enums for the CK Tile convolution specialization. +enum class TileConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3x3 +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a340a789de..eef1110d27 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description @@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp - conv/test_ckb_conv_fwd_1d_fp16.cpp - conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_2d_dl_fp16.cpp - conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp + conv/ck/test_ckb_conv_fwd_1d_fp16.cpp + conv/ck/test_ckb_conv_fwd_1d_bf16.cpp + conv/ck/test_ckb_conv_fwd_1d_i8.cpp + conv/ck/test_ckb_conv_fwd_2d_fp8.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_bf16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp similarity index 100% rename from experimental/builder/test/conv/test_conv_traits.cpp rename to experimental/builder/test/conv/ck/test_conv_traits.cpp diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp new file mode 100644 index 0000000000..ad31fc52bc --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_data", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp new file mode 100644 index 0000000000..47908e0e5b --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_weight", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp new file mode 100644 index 0000000000..083d9d9955 --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_forward", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index d89d83357f..29c7f3cdcc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,73 @@ struct LargeTensorWrapper ConvAlgorithmSpecialization::LARGE_TENSOR; }; +// Specify thread block dimensions for a GEMM (CK Tile). +struct TileThreadBlock +{ + // Size of the submatrix problem in a thread block. + MNK tile_size; +}; +static_assert(ckb::TileThreadBlockDescriptor); + +struct TileTransfer +{ + size_t a_scalar_per_vector; + size_t b_scalar_per_vector; + size_t c_scalar_per_vector; +}; +static_assert(ckb::TileTransferDescriptor); + +struct TileBlockGemm +{ + // Number of warps per each dimension. + MNK warps; + // Number of data processed per each dimension for each XDL/WMMA instruction. + MNK warp_tile; + // Double LDS buffer. + bool double_smem_buffer; + // Waves grouping (Ping-Pong scheduler). + int num_wave_groups; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; +}; +static_assert(ckb::TileBlockGemmDescriptor); + +struct TileOptimizations +{ + // Number of convolution groups processed per one workgroup + int num_groups_to_merge; + // Split image for large tensors + bool split_image; + // Explicit gemm for 1x1, stride=0, pad=0 cases + bool explicit_gemm; +}; +static_assert(ckb::TileOptimizationsDescriptor); + +struct TileConvSpecialization_ +{ + TileConvSpecialization specialization; +}; + +struct TileThreadBlock_ +{ + TileThreadBlock thread_block; +}; + +struct TileTransfer_ +{ + TileTransfer transfer; +}; + +struct TileBlockGemm_ +{ + TileBlockGemm block_gemm; +}; + +struct TileOptimizations_ +{ + TileOptimizations optimizations; +}; + // Factory template @@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components... result.transfer = t; return result; } + + template + constexpr auto with_tile_specializations(const S& s) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.specialization = s; + return result; + } + + template + constexpr auto with_tile_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_tile_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_tile_transfer(const T& t) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; + return result; + } + + template + constexpr auto with_tile_optimizations(const O& o) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.optimizations = o; + return result; + } }; // Algorithm types @@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; +using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp index 84a9c533f6..610edd281e 100644 --- a/experimental/builder/test/unit_conv_elementwise_op.cpp +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 7764e94dc6..26df33cc8d 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "impl/conv_signature_types.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index c92b24626e..7ffd446966 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp index f829708696..ce5a772cfa 100644 --- a/experimental/builder/test/unit_conv_thread_block.cpp +++ b/experimental/builder/test/unit_conv_thread_block.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 82117c53d8..b35a1ced55 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -3,7 +3,7 @@ #include -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" namespace { diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index 508c621c2e..1acf170455 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -28,4 +28,20 @@ constexpr void run_test(const std::vector& kernel_instance_componen } } +// Common CK Tile test implementation +template +constexpr void run_ck_tile_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + std::cout << kernel_string << std::endl; + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp new file mode 100644 index 0000000000..377234dd19 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr TileTransfer FwdTileTransfer_1x1x1{ + .a_scalar_per_vector = 1, + .b_scalar_per_vector = 1, + .c_scalar_per_vector = 1, +}; + +constexpr TileTransfer FwdTileTransfer_4x4x4{ + .a_scalar_per_vector = 4, + .b_scalar_per_vector = 4, + .c_scalar_per_vector = 4, +}; + +constexpr TileTransfer FwdTileTransfer_8x8x8{ + .a_scalar_per_vector = 8, + .b_scalar_per_vector = 8, + .c_scalar_per_vector = 8, +}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index d4475e8c60..8fae704203 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK)); + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 2c6b1f3d48..e35f4ce70d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem return concat('_', "gemm_problem", concat('x', kBlockSize), concat('x', kPadM, kPadN, kPadK), - Scheduler); + Scheduler, + "NumWaveGroups", + NumWaveGroups, + "DoubleSmemBuffer", + DoubleSmemBuffer + ); // clang-format on } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index e172e732fa..46c60cb6d7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off return concat('_', "grouped_convolution_backward_data", gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, "gemm", GemmPipeline::GetName(), "epilogue", - EpiloguePipeline::GetName()); + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 6ef1d84a6e..f43bfdacac 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { - constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } else { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), "merge", NumGroupsToMerge); - } + return concat('_', "grouped_convolution_backward_weight", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 72ba17c5a5..a9f3274805 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel { constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), - "merge", - NumGroupsToMerge); - } else { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 71739c9083..5b00e53af8 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -9,6 +9,13 @@ namespace ck_tile { +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + /// @brief The Grouped Conv kernel host arguments. /// /// @par Overview @@ -113,6 +120,36 @@ struct GroupedConvTraits using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + template + struct GemmLayouts + { + static_assert(false, "Unsupported direction."); + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutFwd; + using BsLayout = BsLayoutFwd; + using CLayout = CLayoutFwd; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdData; + using BsLayout = BsLayoutBwdData; + using CLayout = CLayoutBwdData; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdWeight; + using BsLayout = BsLayoutBwdWeight; + using CLayout = CLayoutBwdWeight; + }; + template using GroupedConvImplicitGemmTraitsFwd = TileGemmTraits; From 878b4e7f46d7e47618f4d860d71b438cb6d992fd Mon Sep 17 00:00:00 2001 From: Yi DING Date: Mon, 8 Dec 2025 19:20:44 +0800 Subject: [PATCH 15/24] [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 (#3287) * [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2 * typo --- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 134 +++++++----------- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 81 ++++++----- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 47 +++++- 3 files changed, 141 insertions(+), 121 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index d9fb144176..1133da33ad 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel; - using TilePartitioner = remove_cvref_t; - using FlatmmPipeline = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using MXFlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; using DsLayout = remove_cvref_t; using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; + static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; @@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel::PackedSize; static constexpr int BPackedSize = numeric_traits::PackedSize; - static constexpr int MXdlPack = FlatmmPipeline::MXdlPack; - static constexpr int NXdlPack = FlatmmPipeline::NXdlPack; - static constexpr int KXdlPack = FlatmmPipeline::KXdlPack; + static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack; + static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack; + static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack; static constexpr index_t NumDTensor = DsDataType::size(); @@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel, FlatmmPipeline::GetName()); + return concat('_', "mx_flatmm_gemm", gemm_prec_str, MXFlatmmPipeline::GetName()); // clang-format on } @@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); }(); - constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock; + constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; const auto& b_flat_tensor_view = [&]() { - static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0, + static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( make_tuple(kFlatN, kFlatKBlocks, number{})); @@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); }(); const auto& b_flat_tensor_view = views.at(I1); @@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(d_tensor_view[i], make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }, number{}); @@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); @@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); }(); const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); const auto ds_block_window = generate_tuple( @@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel(kargs.a_ptr) + - splitk_batch_offset.a_k_split_offset / APackedSize; - const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + - splitk_batch_offset.b_k_split_offset / BPackedSize; + const auto a_ptr = static_cast(kargs.a_ptr) + + splitk_batch_offset.a_k_split_offset / APackedSize; + const auto b_flat_ptr = static_cast(kargs.b_ptr) + + splitk_batch_offset.b_k_split_offset / BPackedSize; EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel::value)) { - constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, b_flat_ptr, kargs.ds_ptr, diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ff799cb0fc..87ae7f57d8 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem; using CLayout = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + using BlockFlatmm = remove_cvref_t())>; @@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1::PackedSize; - static constexpr index_t BPackedSize = numeric_traits::PackedSize; + // static constexpr index_t WG_AKPacks = WG::kK / APackedSize; + // static constexpr index_t WG_BKPacks = WG::kK / BPackedSize; static constexpr index_t MXdlPack = Problem::MXdlPack; static constexpr index_t NXdlPack = Problem::NXdlPack; static constexpr index_t KXdlPack = Problem::KXdlPack; static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK; - static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; - static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; + static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload @@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMX_BFlatDramTileDistribution()); + auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow( + b_flat_dram_block_window_tmp); auto b_flat_dram_offsets = generate_tuple( [&](auto nIter) { constexpr auto packed_n_idx = nIter / number{}; @@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, true_type{}, false_type{}); + async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{}); }; // HEAD @@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); // move B window to next flat K b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); // prefetch Scale A @@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); }); @@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 969cddf3e7..4d76ab7da2 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; + using BDataType = remove_cvref_t; + constexpr index_t BPack = numeric_traits::PackedSize; static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); @@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>, tile_distribution_encoding< // sequence, - tuple, // 4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>>{}); } + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) + { + + using BDataType = remove_cvref_t; + constexpr auto BPackedSize = numeric_traits::PackedSize; + constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; + constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); + constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; + constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; + + static_assert(std::decay_t::get_num_of_dimension() == 2); + auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); + const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + auto&& byte_tensor_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple( + flat_n, flat_k / flat_k_per_block, number{})), + make_tuple(make_pass_through_transform(flat_n), + make_merge_transform_v3_division_mod(make_tuple( + flat_k / flat_k_per_block, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + auto&& byte_tensor_view = + make_tensor_view(byte_ptr, byte_tensor_desc); + auto&& origin_tmp = window_tmp.get_window_origin(); + return make_tile_window( + byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / BPackedSize}, + MakeMX_BFlatBytesDramTileDistribution()); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { From ca6143f0b2237a1af80ef5550f1b774fd463676d Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Mon, 8 Dec 2025 20:44:17 +0500 Subject: [PATCH 16/24] Add a workaround for a compiler issue for bwd on gfx90a and ROCm 7.1.1 (#3369) Sometimes there are not enough wait-states between v_mfma_f32... and v_accvgpr_read_b32 instructions if they are separated by s_cbranch. The workaround is to read accvgprs to vgpr before branching. --- .../block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 854e45c432..7cc424597a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -552,6 +552,15 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }); }); } +#if defined(__gfx9__) + else + { + // Workaround for a compiler issue: sometimes there are not enough wait-states + // between v_mfma_f32... and v_accvgpr_read_b32 instructions if they are separated + // by s_cbranch. + tile_elementwise_inout([](auto& x) { asm("; force move to %0" : "+v"(x)); }, s_acc); + } +#endif { bool need_perpixel_check = mask.IsEdgeTile( From fe07b5a1bff597df8a81fb227aee0ac95e06b197 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Mon, 8 Dec 2025 21:19:22 +0100 Subject: [PATCH 17/24] [CK Tile] Grouped GEMM aquant mode and non-persistent kernel (#3337) * wip: add aquant to grouped gemm quant example * fix: properly handle hot loop count in aquant pipeline * fix: add separate GemmConfig structs for AQuant, automatically select the correct one * feat: finish support for a non-persistent kernel invocation for grouped gemm quant, and add support code to example * refactor: cleaned up grouped gemm quant example a bit by reusing pipeline selection logic * chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants * feat: add quant grouped gemm tests cases for aquant (regular and transpose C) and non-persistent kernel * fix: update base pipeline classes according to changes in develop branch * Revert "chore: add warp gemm dispatchers for a couple of TransposeC K=32 variants" This reverts commit b3fd4d326d9ccb13e6902bd470bbe76fb323ba54. * feat: remove aquant config from grouped gemm quant example, update to add persistency as runtime parameter * chore: removed work-around for aquant bug that has been fixed * chore: fix typo in command-line parameters * fix: correct K warp tile size for gfx950 * chore: incorrect warp tile configuration on gfx942 --- .../17_grouped_gemm/quant_grouped_gemm.cpp | 225 ++++++++++-- .../17_grouped_gemm/quant_grouped_gemm.hpp | 75 +++- .../quant_run_grouped_gemm_example.inc | 233 ++++++++---- .../kernel/grouped_gemm_quant_kernel.hpp | 109 +++++- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 49 ++- .../ck_tile/grouped_gemm_quant/CMakeLists.txt | 3 + .../test_grouped_gemm_quant.cpp | 51 +-- .../test_grouped_gemm_quant_aquant.cpp | 38 ++ .../test_grouped_gemm_quant_bquant.cpp | 11 +- .../test_grouped_gemm_quant_rowcol.cpp | 13 +- .../test_grouped_gemm_quant_tensor.cpp | 13 +- .../test_grouped_gemm_util_quant.hpp | 334 +++++++++++++++--- 12 files changed, 948 insertions(+), 206 deletions(-) create mode 100644 test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,14 +9,190 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + template ; // Persistence + GemmConfig::Persistent>; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..0317685770 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -64,6 +64,7 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -83,10 +84,11 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; }; -template -struct GemmConfigComputeV3_2 : public GemmConfigBase +template +struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)"); + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..37832b54ba 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. - // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = + grouped_gemm(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { + return run_grouped_gemm_example_with_layouts typename GemmConfig> +template +int run_gemm_example_persistency( + std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +{ + if(persistent) + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +} + int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); if(data_type == "fp8") { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { @@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..726f678d37 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel auto karg = QuantGroupedGemmKernelArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].e_ptr), type_convert(gemm_descs[i].aq_ptr), type_convert(gemm_descs[i].bq_ptr), - gemm_descs[i].k_batch, + type_convert(gemm_descs[i].e_ptr), M, N, K, @@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel stride_b, stride_e, gemm_descs[i].stride_AQ, - gemm_descs[i].stride_BQ}; + gemm_descs[i].stride_BQ, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel else { - RunGemmWithPipelineSelection(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline @@ -496,6 +528,53 @@ struct QuantGroupedGemmKernel } } + CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + // For persistent kernels template , diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 2bd2571993..7a7ae77730 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 551989421f..6a1a28884a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,32 +18,41 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; +using AQuant = std::integral_constant; using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..8dcd6d017d --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using AQuant = std::integral_constant; + +// clang-format off +using KernelTypes_AQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant +#include "test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 4f44acf4c4..6c0ad545b7 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -20,9 +20,14 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp index 48720aeebf..cc1b32fb20 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index f59fa29ec2..e446f7b168 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 68b6735655..9941066c3e 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using AQLayout = Row; using BQLayout = Col; - static constexpr bool Persistent = true; static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; - - template - static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() - { -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif - } + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; struct GroupedGemKernelParam_Mfma { @@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t N_Warp = 2; static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = - TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 32; }; struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma @@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + template + float invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B + + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + + using QuantGroupSize = ck_tile::QuantGroupShape>; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BQLayout, TransposeC, DoubleSmemBuffer, - true>; + Persistent>; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. - using QuantGemmProblem = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantType == ck_tile::QuantType::RowColQuant || - QuantType == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / 128; // Group quantization: BQK = K / GroupSize - if(K % 128 != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, 0, stride_BQs[i], is_row_major(BQLayout())))); + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + 0, 0, stride_AQs[i], is_row_major(AQLayout{})))); bq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); @@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if constexpr(Persistent) { // Generate kernel arguments std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); assert(gemm_descs[0].k_batch == 1); for(const auto& arg : gemm_descs) { @@ -471,7 +680,14 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test } else { - GTEST_FAIL() << "Non-persistent kernel not implemented yet"; + const auto stream = ck_tile::stream_config{nullptr, false, 1}; +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#else + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#endif } // Copy results back to host for validation @@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; +template +using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; + template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; From c363a98d4154c647c1a2d5331ad0d76879b84dfa Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 8 Dec 2025 21:05:56 +0000 Subject: [PATCH 18/24] [CK_TILE] Support more layouts for BQuant GEMM (#3349) * WIP: preparing to add transpose bq support * WIP: handle both row/col layout for BQ windows/tile dstr * Fix build * WIP: adding some test, debugging numerical errors * Fix all but pkint4 tests * Remove test_gemm_quant_typed.cpp again * update disabled tests * add conversion from pkint4 for b matrix * fix formatting * fix formatting * Fix tr_load and use override b datatype for clarity * fix formatting * make bquant preshuffle tests bqlayout column-major --- .../block_universal_gemm_as_bs_bquant_cr.hpp | 32 +++- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 66 ++++++-- .../gemm_bquant_pipeline_ag_bg_cr_base.hpp | 14 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 22 ++- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 85 +++++++--- .../pipeline/gemm_group_quant_utils.hpp | 151 ++++++++++++------ .../gemm_block_scale/test_gemm_quant_base.hpp | 4 +- .../test_gemm_quant_bquant.cpp | 76 +++++---- .../test_gemm_quant_bquant_preshuffle.cpp | 90 +++++------ .../test_gemm_quant_fixtures.hpp | 8 +- 10 files changed, 359 insertions(+), 189 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index d97145cbc3..628e9194ae 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; @@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + load_int4_tile( + a_warp_tile_, a_block_window); + // If B datatype were pkint4 it would be converted prior to storing in LDS + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index dd85705cf2..203b79aec6 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -426,7 +426,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -781,7 +780,9 @@ struct QuantGemmKernel { if constexpr(PreshuffleQuant) { - static_assert(std::is_same_v); + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, @@ -791,14 +792,35 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); + + if constexpr(std::is_same_v) + { + // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] + // Dimensions: [K/QuantGroupK, N/QuantGroupN] + // Strides: [N/QuantGroupN, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] + // Dimensions: [N/QuantGroupN, K/QuantGroupK] + // Strides: [K/QuantGroupK, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } } } else @@ -1023,10 +1045,10 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { + using QuantGroupSize = remove_cvref_t; if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; @@ -1042,13 +1064,23 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } } } else diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index 4cd343e640..c570d4a131 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile(), XPerTile()), + make_tuple(YPerTile{}, XPerTile{}), bq_dram_block_window_tmp.get_window_origin(), Policy::template MakeBQDramTileDistribution()); return bq_copy_dram_window; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 870326cb9d..154d068f0a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); + // Support both RowMajor and ColumnMajor layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); + } } template @@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_bq< @@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), VecLoadSize, + BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } else { + // KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups) using TileEncodingPattern = tile_distribution_encoding_pattern_bq; + KPerBlockBQ, // Logical K dimension + NPerBlockBQ, // Logical N dimension + Problem::QuantGroupSize::kN, + BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 4883a30f57..2c191cc2b4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -125,7 +132,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, + const BDramWindow& b_dram_window) + { + using DestDataType = typename BBlockTile_::DataType; + using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(b_block_tile, b_dram_window); + } + template ; - constexpr bool is_bq_col_major = - std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + constexpr bool is_bq_row_major = + std::is_same_v; static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), 0) - : is_bq_col_major ? make_array(0, KPerBlockBQ) - : make_array(KPerBlockBQ, 0); + : is_bq_row_major ? make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // B tile gets converted to A datatype during loading + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // B datatype is converted to A datatype during loading + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType PkInt4 gets converted during loading earlier + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType gets converted during loading from PkInt4 + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -210,36 +211,41 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) /// /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative + /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative /// to warp dimensions. /// /// Three distinct distribution patterns are handled: /// - /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN): + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast - /// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp /// - /// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps): + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN - /// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 /// - /// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps): + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): /// - Quantization group spans multiple warps /// - All warps share the same scale value - /// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale /// /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { + // Preshuffle only supported for ColumnMajor currently + static_assert(!(PreshuffleQuant && std::is_same_v), + "PreshuffleQuant only supported for ColumnMajor BQLayout"); + if constexpr(PreshuffleQuant) { + // ColumnMajor only for preshuffle constexpr index_t X1 = warp_size; - constexpr index_t X0 = XPerTile / warp_size; + constexpr index_t X0 = NPerTile / warp_size; constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = YPerTile / Y1; + constexpr index_t Y0 = KPerTile / Y1; return make_static_tile_distribution( tile_distribution_encoding, @@ -251,52 +257,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(YPerQ < WarpGemm::kN) + if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t X = XPerTile; // Full X dimension of tile - constexpr index_t XR = 1; // No Y replication needed - constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t Y1 = NWarps; // Number of warps in N-dim - constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp - constexpr index_t YR = YPerQ; // Elements per quantization group + // N dimension needs to be partitioned the same way regardless of layout + constexpr index_t NR = 1; // No N replication needed + constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t N1 = NWarps; // Number of warps in N-dim + constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp - static_assert(Y0 * Y1 * Y2 == YPerTile, - "Y0, Y1, Y2 must cover the blocktile along Y."); + static_assert(N0 * N1 * N2 == NPerTile, + "N0, N1, N2 must cover the blocktile along N dimension."); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 1, 0>>, - tuple, sequence<1, 2, 2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else if constexpr(YPerQ <= WarpGemm::kN * NWarps) + else if constexpr(NPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto Y1 = NWarps / YR; // Warps per unique scale - constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto N1 = NWarps / NR; // Warps per unique scale + constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension + + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1), K] - N on Y-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1)] - N on X-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else // XPerQ > WarpGemm::kN * NWarps + else // NPerQ > WarpGemm::kN * NWarps { // Case 3: Coarse-grained - quantization group spans all warps // All warps in N-dimension share the same quantization scale - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [N, K] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, N] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } } } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 38bd59b882..39a7c66f38 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; - // BQLayout is always ColumnMajor for BQuant - using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + // Re-use the AQLayout for BQLayout + using BQLayout = AQLayout; using CodegenGemmTraits = ck_tile::TileGemmQuantTraits>; using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: // clang-format off using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 6cde4bded5..3a62fc091a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape>; using GroupSize2D64N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests with PreshuffleB -// Tuple format: // clang-format off using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7b16529aa8..bf9c7a138d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(BQLayout{}) ? BQN : BQK; // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); - // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); From c1c2e41a0387e8e76970ad86959e28963f569d54 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 9 Dec 2025 11:02:33 +0800 Subject: [PATCH 19/24] [CK_TILE] Generate random tensor values with multiple threads (#3324) --- example/ck_tile/15_fused_moe/main.cpp | 33 ++-- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 16 +- include/ck_tile/host/fill.hpp | 96 ++++++----- include/ck_tile/host/joinable_thread.hpp | 49 ++++++ test/ck_tile/utility/CMakeLists.txt | 2 + test/ck_tile/utility/test_fill.cpp | 156 ++++++++++++++++++ 6 files changed, 286 insertions(+), 66 deletions(-) create mode 100644 test/ck_tile/utility/test_fill.cpp diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index ac174379df..efb83efbd2 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -284,26 +284,25 @@ bool run(const ck_tile::ArgParser& arg_parser) } else if(init == 1) { - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(g_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(d_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sa_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sg_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sd_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sy_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}( - topk_weight_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(g_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(d_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(topk_weight_host); } else if(init == 2) { - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(g_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(d_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sa_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sg_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sd_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sy_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(topk_weight_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(g_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(d_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sa_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sg_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sd_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sy_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(topk_weight_host); } // permute weight diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 44fd12e2d9..cc2c041ed6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -71,17 +71,17 @@ int run_mx_flatmm_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); } else if(init_method == 1) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); } else { diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 12f43ebc5e..4bbf8cbf3f 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -33,59 +33,73 @@ namespace ck_tile { * @example * * // Direct usage without creating a separate variable: - * ck_tile::FillUniformDistribution{-1.f, 1.f}(a_host_tensor); + * ck_tile::FillUniformDistribution<>{-1.f, 1.f}(a_host_tensor); */ -template +template struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; std::optional seed_{11939}; - // ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed - // across threads). - bool threaded = false; template void operator()(ForwardIter first, ForwardIter last) const { - if(threaded) - { - uint32_t num_thread = std::thread::hardware_concurrency(); - auto total = static_cast(std::distance(first, last)); - auto work_per_thread = static_cast((total + num_thread - 1) / num_thread); + if(first == last) + return; + using T_iter = std::decay_t; + static_assert(std::is_same_v || std::is_void_v, + "Iterator value type must match template type T"); + constexpr auto PackedSize = numeric_traits::PackedSize; + const auto total = static_cast(std::distance(first, last)); + const auto total_bytes = total * sizeof(T_iter); - std::vector threads(num_thread); - for(std::size_t it = 0; it < num_thread; ++it) - { - std::size_t iw_begin = it * work_per_thread; - std::size_t iw_end = std::min((it + 1) * work_per_thread, total); - auto thread_f = [this, total, iw_begin, iw_end, &first] { - if(iw_begin > total || iw_end > total) - return; - // need to make each thread unique, add an offset to current seed - std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); - }; - threads[it] = joinable_thread(thread_f); - } - } - else + // max 80 threads; at least 2MB per thread + const size_t available_cpu_cores = get_available_cpu_cores(); + const size_t num_thread = + min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); + constexpr size_t BLOCK_BYTES = 64; + constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter); + const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES); + const size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread); + + // use minstd_rand for better performance on discard() + std::minstd_rand gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + + std::vector threads; + threads.reserve(num_thread - 1); // last job run in the main thread + for(int it = num_thread - 1; it >= 0; --it) { - std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first, last, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); + const size_t ib_begin = it * blocks_per_thread; + const size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks); + + auto job = [=]() { + auto g_ = gen; // copy + auto d_ = dis; // copy + g_.discard(ib_begin * BLOCK_SIZE * PackedSize); + auto t_fn = [&]() { + if constexpr(PackedSize == 2) + return type_convert(fp32x2_t{d_(g_), d_(g_)}); + else + return type_convert(d_(g_)); + }; + + size_t ib = ib_begin; + for(; ib < ib_end - 1; ++ib) // full blocks + static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) { + constexpr size_t iw = iw_.value; + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }); + for(size_t iw = 0; iw < BLOCK_SIZE; ++iw) // last block + if(ib * BLOCK_SIZE + iw < total) + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }; + + if(it > 0) + threads.emplace_back(std::move(job)); + else + job(); // last job run in the main thread } } diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index bf84858ee2..b2e1fc4dac 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -3,6 +3,9 @@ #pragma once +#ifdef __linux__ +#include +#endif #include #include @@ -24,4 +27,50 @@ struct joinable_thread : std::thread this->join(); } }; + +inline unsigned int get_available_cpu_cores() +{ +#if defined(__linux__) + cpu_set_t cpu_set; + if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0) + { + unsigned int cpu_count = CPU_COUNT(&cpu_set); + if(cpu_count > 0) + return cpu_count; + } +#endif + // Fallback if sched_getaffinity unavailable or fails + return std::thread::hardware_concurrency(); +} + +class cpu_core_guard +{ +#if defined(__linux__) + cpu_set_t original_cpu_set_; + + public: + cpu_core_guard(unsigned int num_cores) : original_cpu_set_() + { + // save original cpu set + sched_getaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + + // set new cpu set + cpu_set_t new_cpu_set; + CPU_ZERO(&new_cpu_set); + for(unsigned int i = 0; i < num_cores; ++i) + { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + CPU_SET(i, &new_cpu_set); // NOLINT(old-style-cast) +#pragma clang diagnostic pop + } + sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set); + } + ~cpu_core_guard() + { + // restore original cpu set + sched_setaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + } +#endif +}; } // namespace ck_tile diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt index aa15293411..01ed83841b 100644 --- a/test/ck_tile/utility/CMakeLists.txt +++ b/test/ck_tile/utility/CMakeLists.txt @@ -3,5 +3,7 @@ message("-- Adding: test/ck_tile/utility/") +add_gtest_executable(test_fill test_fill.cpp) + # Add print tests add_subdirectory(print) diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp new file mode 100644 index 0000000000..18f42c4ad0 --- /dev/null +++ b/test/ck_tile/utility/test_fill.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host/fill.hpp" +#include "ck_tile/host/joinable_thread.hpp" +#include +#include +#include +#include + +using namespace ck_tile; + +namespace test { + +// Test fixture for FillUniformDistribution tests +template +class FillUniformDistributionTest : public ::testing::Test +{ + public: + static constexpr uint32_t seed = 42; + static constexpr float a = -5.0f; + static constexpr float b = 5.0f; +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(FillUniformDistributionTest, TestTypes); + +// Test that multiple runs with the same seed produce identical results +TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + constexpr size_t size = 1024 * 1024 * 1024 / sizeof(T); // 1G + + std::vector vec1(size); + auto start = std::chrono::high_resolution_clock::now(); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + auto end = std::chrono::high_resolution_clock::now(); + double sec = std::chrono::duration(end - start).count(); + std::cout << "Taking " << sec << " sec to fill 1GB of data of type " << typeid(T).name() + << std::endl; + + const auto cpu_cores = max(32U, get_available_cpu_cores()); + for(auto num_threads_diff : {-3, -1}) + { + cpu_core_guard cg(min(max(cpu_cores + num_threads_diff, 1U), get_available_cpu_cores())); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "First and second fill should be identical"; + } +} + +// Test consistency across different data sizes (which affects threading) +TYPED_TEST(FillUniformDistributionTest, ConsistencyAcrossSizes) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will use multiple threads + 5000000 // Very large - will use many threads + }; + + for(size_t size : test_sizes) + { + std::vector reference(size); + std::vector test_vec(size); + + FillUniformDistribution{a, b, seed}(reference.begin(), reference.end()); + + // Run multiple times to ensure consistency + for(int run = 0; run < 3; ++run) + { + std::fill(test_vec.begin(), test_vec.end(), T{}); + FillUniformDistribution{a, b, seed}(test_vec.begin(), test_vec.end()); + + EXPECT_EQ(0, std::memcmp(reference.data(), test_vec.data(), size * sizeof(T))) + << "Mismatch for size=" << size << " run=" << run; + } + } +} + +// Test that different seeds produce different results +TYPED_TEST(FillUniformDistributionTest, CommonPrefix) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will use multiple threads + 5000000 // Very large - will use many threads + }; + + auto longest = std::make_unique>(test_sizes[0]); + FillUniformDistribution{a, b, seed}(longest->begin(), longest->end()); + for(size_t i = 1; i < test_sizes.size(); ++i) + { + auto current = std::make_unique>(test_sizes[i]); + FillUniformDistribution{a, b, seed}(current->begin(), current->end()); + size_t min_size = std::min(longest->size(), current->size()); + EXPECT_EQ(0, std::memcmp(longest->data(), current->data(), min_size * sizeof(T))) + << "Different sizes with same seed should have the same prefix"; + if(current->size() > longest->size()) + { + longest = std::move(current); + } + } +} + +// Test edge cases +TYPED_TEST(FillUniformDistributionTest, EdgeCases) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + // Empty range + std::vector empty_vec; + EXPECT_NO_THROW((FillUniformDistribution{a, b, seed}(empty_vec.begin(), empty_vec.end()))); + + // Single element + std::vector single1(1); + std::vector single2(1); + FillUniformDistribution{a, b, seed}(single1.begin(), single1.end()); + FillUniformDistribution{a, b, seed}(single2.begin(), single2.end()); + + EXPECT_EQ(0, std::memcmp(single1.data(), single2.data(), sizeof(T))) + << "Single element should be consistent"; + + // Small sizes that might affect threading decisions + std::vector small_sizes = {2, 3, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65}; + for(size_t size : small_sizes) + { + std::vector vec1(size); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "Edge case failed for size=" << size; + } +} +} // namespace test From 6f0966e1e9fca5c513d16a729237d676b583e266 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 9 Dec 2025 17:54:55 +0800 Subject: [PATCH 20/24] fix a16w4 moe bugs (#3373) * fix valid mask bug * update format --- include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b3b34a6da0..7104547247 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel auto fused_token = kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] - index_t scatter_token_id = fused_token & token_id_mask; + index_t scatter_token_id = fused_token & token_id_mask; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); if constexpr(IsInputGemm) scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); From 50ca3f83ebc08ffe8946c3668fd879e3b2043ef7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 07:10:34 -0800 Subject: [PATCH 21/24] Bump rocm-docs-core[api_reference] from 1.20.1 to 1.31.0 in /docs/sphinx (#3374) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.20.1 to 1.31.0. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/v1.31.0/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.20.1...v1.31.0) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.31.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index beedb4e867..b607daa9ff 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.20.1 +rocm-docs-core[api_reference]==1.31.0 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index e8aa02aa01..fce859cf0e 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.20.1 +rocm-docs-core[api-reference]==1.31.0 # via -r requirements.in rpds-py==0.24.0 # via From 7582c9e73fc3e580a2255988310cb25391f80162 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 07:35:32 -0800 Subject: [PATCH 22/24] Upgrade to ROCm7.1.1 as default compiler. (#3370) * upgrade to rocm7.1.1 as new default compiler * fix jenkinsfile --- Dockerfile | 6 +++--- Dockerfile.compiler | 2 +- Jenkinsfile | 8 ++++---- python/ck4inductor/__init__.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 07327442fe..973dcedcb5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=7.0.1 +ARG ROCMVERSION=7.1.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive RUN set -xe && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ - apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ +RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \ + apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \ apt update && \ apt install python3-setuptools python3-wheel -y && \ apt install rocm-dev -y diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 47bd8294b6..0e2219b7ff 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index 45fd576ab6..b8c570b936 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -288,7 +288,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 2 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -434,7 +434,7 @@ def buildDocker(install_prefix){ } catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" - retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage = docker.build("${image_name}", dockerArgs) withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } @@ -1121,8 +1121,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '7.0.1', - description: 'Specify which ROCM version to use: 7.0.1 (default).') + defaultValue: '7.1.1', + description: 'Specify which ROCM version to use: 7.1.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py index 0eee25ecaa..089a2d439b 100644 --- a/python/ck4inductor/__init__.py +++ b/python/ck4inductor/__init__.py @@ -6,7 +6,7 @@ def __version__(): import subprocess # needs to be manually updated - rocm_version = "7.0.1" + rocm_version = "7.1.1" hash_width = 6 try: hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[ From 0d8259affd4f59eb8b1143b658d83d3800270f43 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:37:13 -0800 Subject: [PATCH 23/24] temporarily disable daily builds on gfx1010 and gfx908 (#3384) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index b8c570b936..3f94820095 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1095,7 +1095,7 @@ def run_pytorch_tests(Map conf=[:]){ //launch develop branch daily jobs CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true From 934ba1208ab7cfc82c20f73b14994b64c3843d2d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 9 Dec 2025 14:39:08 -0800 Subject: [PATCH 24/24] use hipTensor from monorepo for daily builds (#3386) --- Jenkinsfile | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 3f94820095..5f03310cab 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -834,12 +834,14 @@ def Build_CK(Map conf=[:]){ if (params.hipTensor_test && arch == "gfx90a" ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash - rm -rf "${params.hipTensor_branch}".zip - rm -rf hipTensor-"${params.hipTensor_branch}" - wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip - unzip -o "${params.hipTensor_branch}".zip + rm -rf rocm-libraries + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git + cd rocm-libraries + git sparse-checkout init --cone + git sparse-checkout set projects/hiptensor + git checkout "${params.hipTensor_branch}" """ - dir("hipTensor-${params.hipTensor_branch}"){ + dir("rocm-libraries/projects/hiptensor"){ sh """#!/bin/bash mkdir -p build ls -ltr