diff --git a/.gitignore b/.gitignore index 2641a661d8..d8468cf24e 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,9 @@ tags # Editors .vscode +# CMake formatting configuration (local) +.cmake-format.yaml + # Cline .cline* diff --git a/CHANGELOG.md b/CHANGELOG.md index b07e322fe1..15fdb09f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). +## (Unreleased) Composable Kernel 1.3.0 + +### Added +* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. +* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". + +### Changed + +### Upcoming changes + ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d0c4d79f9..acae1f5ece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + # definition will be added based on the GPU target in the following section + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -106,6 +110,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") set(CK_ENABLE_FP8 "ON") @@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") set(CK_GFX950_SUPPORT "ON") endif() +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") +else() + message(STATUS "Disabling TF32 instances") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") +endif() + option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable 
FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) @@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() diff --git a/Dockerfile b/Dockerfile index 07327442fe..973dcedcb5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=7.0.1 +ARG ROCMVERSION=7.1.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive RUN set -xe && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ - apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ +RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \ + apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \ apt update && \ apt install python3-setuptools python3-wheel -y && \ apt install rocm-dev -y diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 47bd8294b6..0e2219b7ff 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch 
index 4533166c06..2d3856fa2d 100644 --- a/Dockerfile.pytorch +++ b/Dockerfile.pytorch @@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \ git sparse-checkout set projects/hipblaslt shared/origami && \ cd projects/hipblaslt && \ git show --oneline -s && \ - CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller + CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller diff --git a/Jenkinsfile b/Jenkinsfile index 45fd576ab6..5f03310cab 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -288,7 +288,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 2 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -434,7 +434,7 @@ def buildDocker(install_prefix){ } catch(Exception ex){ echo "Unable to locate image: ${image_name}. 
Building image now" - retimage = docker.build("${image_name}", dockerArgs + ' .') + retimage = docker.build("${image_name}", dockerArgs) withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } @@ -834,12 +834,14 @@ def Build_CK(Map conf=[:]){ if (params.hipTensor_test && arch == "gfx90a" ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash - rm -rf "${params.hipTensor_branch}".zip - rm -rf hipTensor-"${params.hipTensor_branch}" - wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip - unzip -o "${params.hipTensor_branch}".zip + rm -rf rocm-libraries + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git + cd rocm-libraries + git sparse-checkout init --cone + git sparse-checkout set projects/hiptensor + git checkout "${params.hipTensor_branch}" """ - dir("hipTensor-${params.hipTensor_branch}"){ + dir("rocm-libraries/projects/hiptensor"){ sh """#!/bin/bash mkdir -p build ls -ltr @@ -1095,7 +1097,7 @@ def run_pytorch_tests(Map conf=[:]){ //launch develop branch daily jobs CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true @@ -1121,8 +1123,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '7.0.1', - description: 'Specify which ROCM version to use: 7.0.1 (default).') + defaultValue: '7.1.1', + description: 'Specify which ROCM version to use: 7.1.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', diff --git a/README.md b/README.md index 01d523c2ab..8a5258bab6 100644 --- a/README.md +++ b/README.md @@ -187,7 +187,7 @@ limit the number of threads. 
For example, if you have a 128-core CPU and 128 Gb Additional cmake flags can be used to significantly speed-up the build: -* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 2ed338d08a..cab84f5c6c 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -27,6 +27,9 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP32) set(CK_ENABLE_FP32 "ON") endif() + if (DTYPES MATCHES "tf32") + set(CK_ENABLE_TF32 "ON") + endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) set(CK_ENABLE_FP64 "ON") @@ -41,6 +44,7 @@ else() set(CK_ENABLE_INT8 "ON") set(CK_ENABLE_FP16 "ON") set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_TF32 "ON") set(CK_ENABLE_FP64 "ON") set(CK_ENABLE_BF16 "ON") if (GPU_TARGETS MATCHES "gfx94") @@ -67,6 +71,14 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() + if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) + add_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "ON") + else() + message(STATUS "Disabling TF32 instances for this target") + remove_definitions(-DCK_ENABLE_TF32) + set(CK_ENABLE_TF32 "OFF") + endif() else() add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index beedb4e867..b607daa9ff 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.20.1 +rocm-docs-core[api_reference]==1.31.0 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index e8aa02aa01..fce859cf0e 100644 --- 
a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.20.1 +rocm-docs-core[api-reference]==1.31.0 # via -r requirements.in rpds-py==0.24.0 # via diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 0c8102a70b..6e7d69281d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -# add fmha_fwd_v3 example -set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") - -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" -) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp - ${FMHA_FWD_V3_INSTANCES} -) - -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -fgpu-flush-denormals-to-zero - -Wno-undefined-func-template - --save-temps -) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) - -check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) -if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 - ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - -DCK_TILE_DISABLE_PACKED_FP32=1 - ) -endif() - -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, 
otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 333579ec8d..a3cfe2622a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -30,16 +30,24 @@ _MASK_MAP = { } -def get_mask_map(mask: str): - if mask == "generic": +def get_mask_map(mask_impl: str): + if mask_impl == "generic": return _MASK_MAP - elif mask == "simplified": + elif mask_impl == "simplified": return _MASK_SIMPLIFIED_MAP else: assert False return None +def get_mask_impl(mask: str) -> str: + return "simplified" if mask.startswith("s_") else "generic" + + +def get_mask_cpp_type(mask: str) -> str: + return get_mask_map(get_mask_impl(mask))[mask] + + _MASK_CHECK_MAP = { "no": "t.mask_type == mask_enum::no_mask", "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", @@ -62,6 +70,10 @@ def get_mask_check_map(mask: str): return None +def get_mask_cpp_check_expr(mask: str) -> str: + return get_mask_check_map(get_mask_impl(mask))[mask] + + QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", @@ -122,6 +134,7 @@ PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", + "qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline", } PIPELINE_ENUM_MAP = { @@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = { "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3", } BOOL_MAP = { diff --git 
a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 17d4f6e1d7..c00bdcea3b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -8,14 +8,13 @@ import os from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional, Tuple +from typing import Callable, ClassVar, Iterable, List, Optional, Tuple from codegen.arch import ArchTrait, get_factories_for_targets from codegen.cmake_config import GEN_DIR from codegen.cpp_symbol_map import ( LAYOUT_MAP, BIAS_CHECK_MAP, - get_mask_check_map, BOOL_MAP, PIPELINE_MAP, PIPELINE_ENUM_MAP, @@ -23,6 +22,8 @@ from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + get_mask_cpp_type, + get_mask_cpp_check_expr, QSCALE_CHECK_MAP, QSCALE_MAP, ) @@ -48,79 +49,79 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT #include "fmha_fwd.hpp" """ -FMHA_FWD_KERNEL_BODY = """ +FMHA_FWD_KERNEL_BODY_TEMPLATE = """ #include #if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check}) -using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_dtype = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - 
{F_dropout}, - {F_qscale}, - {F_occupancy}, - {F_skip}>; +using fmha_traits = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_qscale}, + {F_occupancy}, + {F_skip}>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; +using fmha_mask = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, +using fmha_pipeline_problem = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape, {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, + fmha_variant, + fmha_mask, {F_trload}, - fmha_trait_{F_idx}>; + fmha_traits>; -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; -using 
fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; +using fmha_kernel = {F_kernel}; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ - using k_ = fmha_kernel_{F_idx}; + using k_ = fmha_kernel; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + std::cout << ", {F_kname}" << std::flush; + auto [kargs, grids] = {F_kargs_creator}(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -130,40 +131,47 @@ float fmha_fwd_(const ck_tile::stream_config& s, fm """ FMHA_FWD_API_FILENAME = "fmha_fwd_api.cpp" -FMHA_FWD_API = """ +FMHA_FWD_API_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py #include #include -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ +#include "fmha_fwd.hpp" + +namespace { +bool get_num_cus(unsigned& num_cus) { int device; auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device"); return false; - }} + } - hipDeviceProp_t props{{}}; + hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ + if(status != hipSuccess) { fprintf(stderr, "failed to get device properties"); return false; - }} + } num_cus = props.multiProcessorCount; return true; -}} +} -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) { const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s) {{ +} +} // namespace +""" +FMHA_FWD_API_FUNC_TEMPLATE = """ +namespace {{ +float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -182,6 +190,28 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& {F_dispatch} return r; }} +}} // namespace +""" +FMHA_FWD_API_FOOTER_TEMPLATE = """ +float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ + const std::string device_name = ck_tile::get_device_name(); + + const bool is_swa = (traits.mask_type != mask_enum::no_mask) and + ((0 < args.window_size_left) or (0 < args.window_size_right)); + const bool 
can_dispatch_v3 = + (device_name.compare(0, 6, "gfx950") == 0) and + (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and + traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and + (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and + (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and + (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and + (args.hdim_v == 128); + if ({F_is_v3_enabled} and can_dispatch_v3) {{ + return fmha_fwd_v3(traits, args, config); + }} else {{ + return fmha_fwd_v2(traits, args, config); + }} +}} """ FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ @@ -261,7 +291,7 @@ class FmhaFwdApiTrait: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.pipeline_tag in ["qr_async", "qr_async_trload", "qr_async_trload_v3"]: if self.spad == "t": return "true" # always support else: @@ -294,7 +324,7 @@ class FmhaFwdApiTrait: return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" - elif self.pipeline_tag == "qr_async_trload": + elif self.pipeline_tag in ["qr_async_trload", "qr_async_trload_v3"]: if self.skpad == "t": return "true" else: @@ -310,7 +340,7 @@ class FmhaFwdApiTrait: return f"a.hdim_q % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == "t": return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) @@ -327,7 +357,7 @@ class FmhaFwdApiTrait: return f"a.hdim_v % {vec} == 0" else: assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == "t": return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) @@ -429,9 +459,8 @@ class FmhaFwdPipeline: class FmhaFwdApiPool: - def __init__(self, mask_impl): + def __init__(self): self.pool = OrderedDict() - self.mask_impl = mask_impl def register_traits(self, trait: FmhaFwdApiTrait) -> None: hdim = trait.hdim, trait.bn1 @@ -443,19 +472,60 @@ class FmhaFwdApiPool: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - @property - def api(self) -> str: + def get_num_traits( + self, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> int: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + return sum( + sum(1 for trait in pool_by_hdim if filter_fn(trait)) + for pool_by_arch in self.pool.values() + for pool_by_dtype in pool_by_arch.values() + for pool_by_hdim in pool_by_dtype.values() + ) + + def render( + self, func_name, filter_fn: Optional[Callable[[FmhaFwdApiTrait], bool]] = None + ) -> str: + if filter_fn is None: + + def accept_all(trait: FmhaFwdApiTrait) -> bool: + return True + + filter_fn = accept_all + + def has_traits(node) -> bool: + """Recursively traverse nested OrderedDicts and lists to determine if any FmhaFwdApiTrait satisfies filter_fn().""" + if isinstance(node, list): + return any(filter_fn(elem) for elem in node) + elif isinstance(node, OrderedDict): + return any(has_traits(val) for val in node.values()) + return False + per_arch = str() - for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): + for i_arch, (arch, pool_by_arch) in enumerate( + item for item in self.pool.items() if 
has_traits(item[1]) + ): per_dtypes = str() - for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): + for i_dtype, (dtype, pool_by_dtype) in enumerate( + item for item in pool_by_arch.items() if has_traits(item[1]) + ): per_hdim_case = str() for i_hdim, ((hdim, hdim_v), pool_by_hdim) in enumerate( - pool_by_dtype.items() + item for item in pool_by_dtype.items() if has_traits(item[1]) ): - max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0) + max_bm0 = max( + (t.bm0 for t in pool_by_hdim if filter_fn(t)), default=0 + ) inners = str() - for i_trait, trait in enumerate(pool_by_hdim): + for i_trait, trait in enumerate( + [trait for trait in pool_by_hdim if filter_fn(trait)] + ): inners += FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_(i_trait), F_arch=arch, @@ -463,8 +533,8 @@ class FmhaFwdApiPool: F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], - F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_mask=get_mask_cpp_type(trait.mask), + F_mask_check=get_mask_cpp_check_expr(trait.mask), F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], @@ -506,10 +576,9 @@ class FmhaFwdApiPool: F_arch=arch, F_dtype_case=indent(per_dtypes), ) - if not per_arch: - # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=indent(per_arch)) + return FMHA_FWD_API_FUNC_TEMPLATE.format( + F_func_name=func_name, F_dispatch=indent(per_arch) + ) @dataclass @@ -548,18 +617,32 @@ class FmhaFwdTileSize: @dataclass class FmhaFwdKernel: F_arch: ArchTrait - F_idx: int # this is not a tunable, but a counter to differentiate symbol F_hdim: int # hdim F_dtype: str # data type F_mode: str # value from MODE_MAP F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline - mask_impl: str - @property - def 
template(self) -> str: - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( - F_idx=self.F_idx, + _KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER + _KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE + + @classmethod + def _get_cpp_kernel_class_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "ck_tile::FmhaFwdV3Kernel" + else: + return "ck_tile::FmhaFwdKernel" + + @classmethod + def _get_cpp_kargs_creator_func_name(cls, pipeline_tag): + if pipeline_tag == "qr_async_trload_v3": + return "fmha_fwd_v3_create_kargs_and_grids" + else: + return "fmha_fwd_create_kargs_and_grids" + + def render(self) -> str: + return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format( + F_kname=self.name, F_arch=self.F_arch, F_hdim=self.F_hdim, F_dtype=FWD_DTYPE_MAP[self.F_dtype], @@ -594,10 +677,12 @@ class FmhaFwdKernel: F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mask=get_mask_cpp_type(self.F_pipeline.F_mask), F_mode=MODE_MAP[self.F_mode], - F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), + F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), ) @property @@ -644,16 +729,179 @@ class FmhaFwdKernel: ) -class KernelComponentFactoryGfx9: +@dataclass +class ProblemContext: + dtype: str + mode: str + hdim: int + hdim_v: int + + +@dataclass +class KernelContext: + tile: FmhaFwdTileSize + pipeline: FmhaFwdPipeline + mask_impl: str + + +CompatibilityRule = Callable[[ProblemContext, KernelContext], bool] + + +def is_compatible( + problem_ctx: ProblemContext, + kernel_ctx: KernelContext, + rules: Iterable[CompatibilityRule], +) -> bool: + return all(rule(problem_ctx, kernel_ctx) for rule in 
rules) + + +def create_kernel( + arch: ArchTrait, problem_ctx: ProblemContext, kernel_ctx: KernelContext +) -> FmhaFwdKernel: + return FmhaFwdKernel( + F_arch=arch, + F_dtype=problem_ctx.dtype, + F_mode=problem_ctx.mode, + F_hdim=problem_ctx.hdim, + F_tile=kernel_ctx.tile, + F_pipeline=kernel_ctx.pipeline, + ) + + +class CompatibilityRuleFactory: + @staticmethod + def get_rules() -> list[CompatibilityRule]: + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + if problem_ctx.mode == "group": + if ( + kernel_ctx.pipeline.F_spad != "t" + or kernel_ctx.pipeline.F_skpad != "t" + ): + return False + return True + + def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128): + if ( + kernel_ctx.pipeline.F_bias != "no" + or kernel_ctx.pipeline.F_dropout == "t" + ): + False + return True + + def check_feature( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + # logits_soft_cap is only allowed if no bias + if not ( + ( + kernel_ctx.pipeline.F_logits == "t" + and kernel_ctx.pipeline.F_bias == "no" + ) + or kernel_ctx.pipeline.F_logits == "f" + ): + return False + return True + + return [check_mode, check_hdim, check_feature] + + +class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): + _AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"}) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactory.get_rules() + + def check_hdim_tile( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if problem_ctx.dtype != "fp32": + # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support + if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( + ( + (problem_ctx.hdim, 
problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 != 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) + and kernel_ctx.tile.F_bm0 != 128 + ) + ): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + return False + return True + + rules.append(check_hdim_tile) + return rules + + +class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9): + _AVAILABLE_PIPELINES = ( + CompatibilityRuleFactoryGfx9._AVAILABLE_PIPELINES + | frozenset({"qr_async_trload", "qr_async_trload_v3"}) + ) + + @classmethod + def get_rules(cls) -> list[CompatibilityRule]: + rules = CompatibilityRuleFactoryGfx9.get_rules() + + def check_tile_pipeline( + problem_ctx: ProblemContext, kernel_ctx: KernelContext + ) -> bool: + if kernel_ctx.pipeline.tag == "qr_async_trload" and ( + ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.tile.F_bn0 == 128 + ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + ) + ): + return False + + # only qr_async_trload_v3 use km0=256 & 8-warps + is_v3_dedicated_tile = ( + kernel_ctx.tile.F_bm0 == 256 + and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8 + and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8 + ) # fmt: skip + is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3" + return is_v3_dedicated_tile == is_v3_pipeline + + rules.extend([check_tile_pipeline]) + return rules + + +class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): arch = ArchTrait( "gfx9", preprocessor_check="defined(__gfx9__) && !defined(__gfx950__)" ) + _DT_FP32 = ("fp32",) + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8 = ("fp8",) + _DT_FP8BF16 = ("fp8bf16",) + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return ( + cls._DT_FP32 + + cls._DT_FP16_BF16 + + cls._DT_FP8 + + cls._DT_FP8BF16 + + 
cls._DT_FP8FP32 + ) + # TODO: design a more practical way to do it # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp32"]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP32: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -666,7 +914,7 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: return { ( 32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -682,30 +930,32 @@ class KernelComponentFactoryGfx9: (192, 192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8 or dtype in cls._DT_FP8BF16: return { ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256, 256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { (128, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 
32, 32, 32, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ["fp32"]: + if dtype in cls._DT_FP32: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -718,7 +968,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp16", "bf16"]: + elif dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -743,7 +993,7 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, 
mask, bias in itertools.product( ["f"], @@ -755,21 +1005,33 @@ class KernelComponentFactoryGfx9: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO - None - else: - assert False + pass return pipelines -class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): +class KernelComponentFactoryGfx950( + KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx950 +): arch = ArchTrait("gfx950") - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) + if dtype in cls._DT_FP16_BF16: + # add tile for qr_async_trload_v3 + if (128, 128) in result.keys(): + result[(128, 128)].append( + FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip + return result + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( dtype, hdim, hdim_v, receipt, mask_impl ) - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -788,15 +1050,31 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): ): pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for mask in ["no", "causal"]: + 
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t")) # fmt: skip + return pipelines -class KernelComponentFactoryGfx12: +class KernelComponentFactoryGfx12(CompatibilityRuleFactory): arch = ArchTrait("gfx12") - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: - if dtype in ["fp16", "bf16"]: + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], @@ -805,25 +1083,27 @@ class KernelComponentFactoryGfx12: (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8", "fp8bf16"]: + elif dtype in cls._DT_FP8_FP8BF16: return { # bm0, bn0, bk0, bn1, bk1, ( 64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], (256, 256) : [FmhaFwdTileSize( 64, 32, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip - elif dtype in ["fp8fp32"]: + elif dtype in cls._DT_FP8FP32: return { # bm0, bn0, bk0, bn1, bk1, (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], } # fmt: skip else: - return None + raise ValueError(f"unsupported dtype={dtype}") - @staticmethod - def get_pipelines(dtype, hdim, 
hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> List[FmhaFwdPipeline]: pipelines = [] - if dtype in ["fp16", "bf16"]: + if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], @@ -835,23 +1115,21 @@ class KernelComponentFactoryGfx12: ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip - else: - assert False return pipelines -class CustomFactory(KernelComponentFactoryGfx9): - @staticmethod - def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: +class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype) - if dtype == "fp16" or dtype == "bf16": + if dtype in cls._DT_FP16_BF16: if (128, 128) in result.keys(): result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip return result @@ -874,150 +1152,162 @@ def get_factory(target: str): raise Exception(f"Unsupported device target 
{target}") +@dataclass(frozen=True) +class Product: + name: str + rule: CompatibilityRule + + def __call__(self, problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + return self.rule(problem_ctx, kernel_ctx) + + +def get_product(receipt: int) -> Product: + # Flash attention integration + if receipt in (2, 3): + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + return cond + + return Product(name="Flash attention integration", rule=fit) + # PyTorch integration + elif receipt == 4: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "bias"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="PyTorch integration", rule=fit) + # Aiter(mha_fwd) integration + elif receipt == 100: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_fwd) integration", rule=fit) + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= problem_ctx.mode == "group" + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if 
problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="Aiter(mha_varlen_fwd) integration", rule=fit) + # aiter::mha_fwd C++ api integration + elif receipt == 600: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16", "fp8bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + if problem_ctx.dtype == "fp8bf16": + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + elif receipt == 888: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp8bf16", "fp8fp32"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= problem_ctx.hdim == 128 or problem_ctx.hdim == 192 + return cond + + return Product(name="receipt = 888", rule=fit) + # fp32 only, all variations + elif receipt == 800: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="fp32 only, all variations", rule=fit) + # fp32 only, minimal set of parameters + elif receipt == 801: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype == "fp32" + cond &= problem_ctx.hdim in [48, 128] + cond &= problem_ctx.mode == "batch" + cond &= kernel_ctx.pipeline.F_bias == "no" + cond &= kernel_ctx.pipeline.F_lse == "f" + cond &= kernel_ctx.pipeline.F_dropout == "f" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_mask == "s_no" + return cond + + return Product(name="fp32 only, minimal set of parameters", rule=fit) + # Don't build fp32 by default + else: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> 
bool: + return problem_ctx.dtype != "fp32" + + return Product(name="Default", rule=fit) + + def get_fwd_blobs( targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() - api_pool = FmhaFwdApiPool(mask_impl) + api_pool = FmhaFwdApiPool() factories = get_factories_for_targets(targets, get_factory) - for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()): + for factory, dtype in ((f, t) for f in factories for t in f.supported_dtypes()): d = factory.get_hdim_tile_size_dict(dtype) - if d is None: - continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product( d.items(), MODE_MAP.keys() ): + if optdim_list != [-1]: + if hdim not in optdim_list: + continue for tile, next_tile in zip(tiles, tiles[1:]): assert next_tile.F_bm0 >= tile.F_bm0, ( "Tiles must be ordered by increasing bm0" ) + for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue - if factory.arch.name.startswith("gfx9") and dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 
"qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): - continue - k = FmhaFwdKernel( - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, + problem_ctx = ProblemContext( + dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v ) + kernel_ctx = KernelContext( + tile=tile, pipeline=pipeline, mask_impl=mask_impl + ) + rules = factory.get_rules() + product = get_product(receipt) + + if not is_compatible(problem_ctx, kernel_ctx, [*rules, product]): + continue + + k = create_kernel(factory.arch, problem_ctx, kernel_ctx) if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - cond &= pipeline.F_skip == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - cond &= mode == "batch" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout 
== "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= pipeline.F_vlayout == "row" - if dtype == "fp8bf16": - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - elif receipt == 888: - cond = dtype in ["fp8bf16", "fp8fp32"] - cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 or hdim == 192 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == "fp32" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == "fp32" - cond &= hdim in [48, 128] - cond &= mode == "batch" - cond &= pipeline.F_bias == "no" - cond &= pipeline.F_lse == "f" - cond &= pipeline.F_dropout == "f" - cond &= pipeline.F_skip == "f" - cond &= pipeline.F_logits == "f" - cond &= pipeline.F_mask == "s_no" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == "fp32": - continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -1026,11 +1316,34 @@ def get_fwd_blobs( def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) + update_file(autogen_dir / kernel.filename, kernel.render()) -def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) +def write_fwd_api( + api_pool: FmhaFwdApiPool, + autogen_dir: Path, +) -> None: + def accept_only_v3(trait: FmhaFwdApiTrait) -> bool: + return trait.pipeline_tag == "qr_async_trload_v3" + + def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: + return not accept_only_v3(trait) + + content = "".join( + [ + FMHA_FWD_API_HEADER, + api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), + api_pool.render("fmha_fwd_v3", 
filter_fn=accept_only_v3), + FMHA_FWD_API_FOOTER_TEMPLATE.format( + F_is_v3_enabled=BOOL_MAP[ + # NOTE: enable v3 pipelines when ready + # 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) + False + ] + ), + ] + ) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) def write_blobs( diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp deleted file mode 100644 index c510b36bb5..0000000000 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -auto parse_cmd_args(int argc, char* argv[]) -> std::pair -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "-1", - "num of head, for k/v, -1 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "-1", "seqlen_k, -1 means equal to s") - .insert("d", "128", "head dim for q & k") - .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") - .insert("iperm", - "0", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "0", "permute output") - .insert("causal", "0", "0: no mask, 1: causal mask") - .insert("v", "1", "0:no verify, 1:verify") - .insert("seed", - "11939", - "random seed used for initializing input tensors. 
0 for " - "non-deterministic seed") - .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "30", "number of iterations to benchmark the kernel") - // Optional effective seqlen override (exclude PAD) for batch mode - .insert("q_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override.") - .insert("kv_eff_lens", - "", - "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); - - bool result = arg_parser.parse(argc, argv); - return std::make_pair(result, arg_parser); -} - -enum class TensorLayout -{ - bhsd, - bshd, -}; - -std::ostream& operator<<(std::ostream& stream, TensorLayout layout) -{ - switch(layout) - { - case TensorLayout::bhsd: return stream << "bhsd"; - case TensorLayout::bshd: return stream << "bshd"; - default: return stream << "unknown"; - } -} - -struct Problem -{ - explicit Problem(const ck_tile::ArgParser& args) - { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; - batch = args.get_int("b"); - seqlen_q = args.get_int("s"); - seqlen_k = args.get_int("s_k"); - if(seqlen_k < 0) - { - seqlen_k = seqlen_q; - } - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); - if(nhead_kv < 0) - { - nhead_kv = nhead_q; - } - hdim = args.get_int("d"); - softmax_scale = args.get_float("scale_s"); - if(softmax_scale == .0f) - softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? 
TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); - } - - std::vector get_query_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - std::vector get_key_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_value_shape() const - { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } - } - - std::vector get_output_shape() const - { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } - } - - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; - ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_kv; - ck_tile::index_t hdim; - float softmax_scale; - mask_info mask; - TensorLayout input_layout; - TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; -}; - -struct RunConfig -{ - explicit RunConfig(const ck_tile::ArgParser& args) - { - seed = args.get_uint32("seed"); - if(*seed == 0) - { - seed.reset(); - } - - kernel_warmup = args.get_int("warmup"); - kernel_repeat = args.get_int("repeat"); - verify = args.get_bool("v"); - } - - std::optional seed; - int kernel_warmup; - int kernel_repeat; - bool verify; -}; - -template -auto generate_qkv(const Problem& problem, - [[maybe_unused]] std::optional seed = std::nullopt) - -> std::tuple, - ck_tile::HostTensor, - ck_tile::HostTensor> -{ - ck_tile::HostTensor q(problem.get_query_shape()); - ck_tile::HostTensor k(problem.get_key_shape()); - ck_tile::HostTensor 
v(problem.get_value_shape()); - - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); - ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); - - return std::make_tuple(q, k, v); -} - -namespace host { -template -CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, 
k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); - - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host - -template -bool run_impl(const Problem& problem, const RunConfig& run_config) -{ - auto [q, k, v] = generate_qkv(problem, run_config.seed); - - ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); - /// FIXME: use correct size for output tensor. 
just use q size for now since hidm_qk = hdim_v - ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); - - q_buf.ToDevice(q.data()); - k_buf.ToDevice(k.data()); - v_buf.ToDevice(v.data()); - // Ensure output buffer is zero-initialized so padded regions compare cleanly - o_buf.SetZero(); - - ck_tile::fmha_fwd_v3_args args{}; - - args.data_type = problem.data_type; - args.batch = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; - args.nhead_q = problem.nhead_q; - args.nhead_kv = problem.nhead_kv; - args.hdim_qk = problem.hdim; - args.hdim_v = problem.hdim; - args.softmax_scale = problem.softmax_scale; - - args.window_size_left = problem.mask.left; - args.window_size_right = problem.mask.right; - args.mask_type = static_cast(problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) - args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? 
problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) - args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { - std::vector eff; - if(!opt_vec.empty() && opt_vec[0] != -1) - { - eff.assign(opt_vec.begin(), opt_vec.end()); - if(eff.size() < static_cast(problem.batch)) - { - eff.resize(problem.batch, eff.back()); - } - } - else - { - eff.assign(problem.batch, fallback); - } - return eff; - }; - - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); - - // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; - auto calculate_cumulative = [&](const std::vector& per_batch_vec, - std::vector& cum_vec) { - cum_vec.resize(per_batch_vec.size() + 1); - cum_vec[0] = 0; - for(std::size_t i = 0; i < per_batch_vec.size(); ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - }; - - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); - } - - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? 
cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; - - ck_tile::stream_config stream_config{nullptr, - true, - /*log_level=*/0, - run_config.kernel_warmup, - run_config.kernel_repeat}; - - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); - if(!result) - { - std::cerr << "faild to run fmha_fwd_v3()" << std::endl; - return false; - } - - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. 
- return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); - float tflops = static_cast(flop) / 1.e9 / time; - - std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; - - if(!run_config.verify) - { - return true; - } - - // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } - - // If variable lengths are provided, compute per-batch references - // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) - { - // Variable-length aware verification: zero-fill padded region and only compute valid part. 
- o_ref.SetZero(); - - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } - } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } - - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); - - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect 
results!"), rtol, atol); -} - -int main(int argc, char* argv[]) -{ - auto [parse_result, args] = parse_cmd_args(argc, argv); - if(!parse_result) - { - std::cerr << "failed to parse command line arguments" << std::endl; - } - - Problem problem(args); - RunConfig run_config(args); - - const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) - { - return run_impl(problem, run_config); - } - else - { - return run_impl(problem, run_config); - } - }; - - return !run(); -} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index f279ebfcea..002d0a1035 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args) +{ + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. 
+ int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + else + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + template auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp deleted file mode 100644 index 1c0256cc0f..0000000000 --- 
a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" -#include "mask.hpp" - -namespace ck_tile { - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) -{ - switch(data_type) - { - case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; - case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; - default: return stream << "unknown"; - } -} - -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) -{ - if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) - { - if(args.mask_type == static_cast(mask_enum::no_mask)) - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - else - { - using kernel_traits = - fmha_fwd_v3_kernel_traits; - - return fmha_fwd_v3_kernel_dispatch(args, config); - } - } - - return std::make_pair(false, -1.f); -} - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp deleted file mode 100644 index 54cc4960a5..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/host/stream_config.hpp" - -namespace ck_tile { - -struct fmha_fwd_v3_args -{ - enum class data_type_enum - { - fp16, - bf16 - }; - - data_type_enum data_type; - // bool is_varlen; - - index_t batch; - index_t seqlen_q; - index_t seqlen_k; - index_t nhead_q; - index_t nhead_kv; - index_t hdim_qk; - index_t hdim_v; - - float softmax_scale; - - index_t window_size_left; - index_t window_size_right; - index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and - // window_size_right == 0). - - const void* q_ptr; - index_t stride_q; - index_t nhead_stride_q; - index_t batch_stride_q; - - const void* k_ptr; - index_t stride_k; - index_t nhead_stride_k; - index_t batch_stride_k; - - const void* v_ptr; - index_t stride_v; - index_t nhead_stride_v; - index_t batch_stride_v; - - void* o_ptr; - index_t stride_o; - index_t nhead_stride_o; - index_t batch_stride_o; - - // Optional batch-mode cumulative seqlen overrides (exclude PAD) - // If provided, they override per-batch effective lengths to skip tail padding. 
- const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] -}; - -std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp deleted file mode 100644 index 19b8dfed4e..0000000000 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include - -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/container/sequence.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" - -#include "fmha_fwd_v3.hpp" -#include "mask.hpp" - -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ - template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ - } - -namespace ck_tile { - -template -struct fmha_fwd_v3_problem_traits; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype 
= float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::bf16_t; - using acc_dtype = float; - using o_dtype = ck_tile::bf16_t; - using lse_dtype = float; -}; - -template -struct fmha_fwd_v3_kernel_traits -{ - static constexpr auto date_type = DataType; - static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 - using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; - using fmha_warp_gemm_shape = sequence<32, 32, 16>; - using fmha_block_warps = sequence<8, 1, 1>; - - using fmha_shape = TileFmhaShape; - - using fmha_traits = TileFmhaFwdV3Traits; - - using fmha_mask = GenericAttentionMask; - - using fmha_pipeline_problem = - BlockFmhaFwdV3PipelineProblem::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::lse_dtype, - typename fmha_fwd_v3_problem_traits::qkvp_dtype, - typename fmha_fwd_v3_problem_traits::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - fmha_shape, - IsVariableSeqlen, - fmha_mask, - fmha_traits>; - - using fmha_pipeline = BlockFmhaFwdV3Pipeline; - - using epilogue = Default2DEpilogue< - Default2DEpilogueProblem::acc_dtype, - typename fmha_fwd_v3_problem_traits::o_dtype, - true, // kPadM - true, // kPadM - true // UseRawStore - >>; - - using kernel = FmhaFwdV3Kernel; -}; - -template -float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) -{ - /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly - /// maximizes the kernel's performance. 
- int remap_opt = 2; - if(args.mask_type != static_cast(mask_enum::no_mask) && - ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) - { - if(65536 <= args.seqlen_q) - { - remap_opt = 0; - } - else - { - remap_opt = 1; - } - } - - auto kargs = Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - nullptr, // lse_ptr - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_qk, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_kv, - args.softmax_scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - 0, // nhead_stride_lse - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - 0, // batch_stride_lse - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - remap_opt, - args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); - - dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; - - return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} - -// return value: -// first = whether the kernel was launched (true = launched, false = skipped) -// second = elapsed time (ms) of the kernel launch, valid only if first == true -template -std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, - const stream_config& config); - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp deleted file mode 100644 index 463c52b824..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp deleted file mode 100644 index acf79e43f4..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp deleted file mode 100644 index a6366209b2..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp deleted file mode 100644 index a83e37cc68..0000000000 --- a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "fmha_fwd_v3.hpp" -#include "fmha_fwd_v3_impl.hpp" - -namespace ck_tile { - -using kernel_traits = - fmha_fwd_v3_kernel_traits; - -INST_FMHA_FWD_V3_DISPATCH(kernel_traits) - -} // namespace ck_tile diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index ac174379df..efb83efbd2 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -284,26 +284,25 @@ bool run(const ck_tile::ArgParser& arg_parser) } else if(init == 1) { - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(g_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(d_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sa_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sg_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sd_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sy_host); - ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}( - topk_weight_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(g_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(d_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed}(topk_weight_host); } else if(init == 2) { - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(g_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(d_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sa_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sg_host); - ck_tile::FillNormalDistribution{0.f, 1.f, 
seed, true}(sd_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sy_host); - ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(topk_weight_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(g_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(d_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sa_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sg_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sd_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sy_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed}(topk_weight_host); } // permute weight diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index d8b905fe3d..d3b75ac72f 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -9,14 +9,190 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" +template +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + 
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << 
Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + template ; // Persistence + GemmConfig::Persistent>; float ave_time{0}; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = GemmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; - using QuantGemmProblem = typename std::conditional< - QuantMode == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(argc, argv); + int result1 = run_grouped_gemm_example(argc, argv); return result1; } diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index ede683abe6..0317685770 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ 
-64,6 +64,7 @@ struct GemmTypeConfig using CDataType = ck_tile::half_t; }; +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -83,10 +84,11 @@ struct GemmConfigBase static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool DoubleSmemBuffer = false; static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; }; -template -struct GemmConfigComputeV3_2 : public GemmConfigBase +template +struct GemmConfigComputeV3_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; -template -struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase +template +struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr bool DoubleSmemBuffer = true; }; +template +struct GemmQuantConfig; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct GemmQuantConfig +{ + template + 
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + + template + using GemmPipeline = std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>; + + template + using BaseGemmPipeline = + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; +}; + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; auto create_args(int argc, char* argv[]) @@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") .insert("kbatch", "1", "kbatch for SplitK") - .insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol") - .insert("init", "0", "0. Random, 2. One(s) (Constant)"); + .insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 37fab44f77..37832b54ba 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup, float ave_time = 0; - // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have - // the gemm problems known on the host. Instead, we can just pass the pointer - // to the kernel and let the workgroups figure out which tiles to work on. - // This is useful when the gemm problems are generated dynamically. - // In this example however, we generate the `kargs` using the known gemm_descs, - // and copy the gemm descriptions to the device memory. 
- // The contents of the memory pointed to by `kargs_ptr` pointer could be - // written by e.g. another kernel from earlier stage. - std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - assert(args[0].k_batch == 1); - for(const auto& arg : args) + if constexpr(!GemmConfig::Persistent) { - kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, - arg.b_ptr, - arg.aq_ptr, - arg.bq_ptr, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.QK_A, - arg.QK_B, - arg.stride_A, - arg.stride_B, - arg.stride_E, - arg.stride_AQ, - arg.stride_BQ, - arg.k_batch}); + ave_time = + grouped_gemm(args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. 
+ std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); } - const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), - hipMemcpyHostToDevice, - stream.stream_id_)); - ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; @@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc, AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { AQK = 0; // No A quantization BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K 
must be divisible by 128 for BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout)))); } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout)))); + } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout)))); + ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout)))); bq_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); } @@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc, bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantMode == 
ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } -template +template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; @@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { + return run_grouped_gemm_example_with_layouts typename GemmConfig> +template +int run_gemm_example_persistency( + std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[]) +{ + if(persistent) + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + using GemmConfig = GemmQuantConfig::template GemmConfig; + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } +} + int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[]) const std::string b_layout = arg_parser.get_str("b_layout"); const std::string data_type = arg_parser.get_str("prec"); std::string quant_mode = arg_parser.get_str("quant_mode"); + bool persistent = arg_parser.get_bool("persistent"); if(data_type == "fp8") { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else 
if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { @@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[]) { if(quant_mode == "tensor") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::TensorQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "rowcol") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::RowColQuant>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); + } + else if(quant_mode == "aquant") + { + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else if(quant_mode == "bquant") { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::QuantType::BQuantGrouped>( - a_layout, b_layout, argc, argv); + return run_gemm_example_persistency( + a_layout, b_layout, persistent, argc, argv); } else { diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 44fd12e2d9..cc2c041ed6 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -71,17 +71,17 @@ int run_mx_flatmm_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); - ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + 
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); } else if(init_method == 1) { - ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b); } else { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index b0a0d3fee7..5773db236f 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, + ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; @@ -433,7 +440,7 @@ int 
run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { stride_AQ = 0; // No A quantization - stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout)); + stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout)); } else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index ecb1ff933e..bf7e89fcaa 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; }; +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileThreadBlockDescriptor = requires(T t) { + { t.tile_size.m } -> std::convertible_to; + { t.tile_size.n } -> std::convertible_to; + { t.tile_size.k } -> std::convertible_to; +}; + +// Concept for thread block dimensions for a GEMM problem for CK Tile (Block +// size is deduced from block gemm structure). +template +concept TileTransferDescriptor = requires(T t) { + { t.a_scalar_per_vector } -> std::convertible_to; + { t.b_scalar_per_vector } -> std::convertible_to; + { t.c_scalar_per_vector } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). 
+template +concept TileBlockGemmDescriptor = requires(T t) { + { t.warps.m } -> std::convertible_to; + { t.warps.n } -> std::convertible_to; + { t.warps.k } -> std::convertible_to; + { t.warp_tile.m } -> std::convertible_to; + { t.warp_tile.n } -> std::convertible_to; + { t.warp_tile.k } -> std::convertible_to; + { t.double_smem_buffer } -> std::convertible_to; + { t.num_wave_groups } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies optimizations (CK Tile). +template +concept TileOptimizationsDescriptor = requires(T t) { + { t.num_groups_to_merge } -> std::convertible_to; + { t.split_image } -> std::convertible_to; + { t.explicit_gemm } -> std::convertible_to; +}; + // Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this // concept. template @@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires { { T::thread_block } -> ThreadBlockDescriptor; }; +// Concept to check if struct specifies thread block info (CK Tile). +template +concept SpecifiesTileThreadBlock = requires { + { T::thread_block } -> TileThreadBlockDescriptor; +}; + // Concept to check if a struct specifies gridwise XDL GEMM info. template concept SpecifiesGridwiseXdlGemm = requires { @@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; +// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. +template +concept SpecifiesTileTransfer = requires(T t) { + { T::transfer.a_scalar_per_vector } -> std::convertible_to; + { T::transfer.b_scalar_per_vector } -> std::convertible_to; + { T::transfer.c_scalar_per_vector } -> std::convertible_to; +}; + // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. 
template concept SpecifiesLdsTransfer = requires(T t) { @@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires { { T::block_gemm.scheduler } -> std::convertible_to; }; +// Concept to check if struct specifies block GEMM (CK Tile). template -concept SpecifiesFwdConcSpecialization = requires { +concept SpecifiesTileBlockGemm = requires { + { T::block_gemm.warps.m } -> std::convertible_to; + { T::block_gemm.warps.n } -> std::convertible_to; + { T::block_gemm.warps.k } -> std::convertible_to; + { T::block_gemm.warp_tile.m } -> std::convertible_to; + { T::block_gemm.warp_tile.n } -> std::convertible_to; + { T::block_gemm.warp_tile.k } -> std::convertible_to; + { T::block_gemm.double_smem_buffer } -> std::convertible_to; + { T::block_gemm.num_wave_groups } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; +}; + +// Concept to check if struct specifies block GEMM (CK Tile). +template +concept SpecifiesTileOptimizations = requires { + { T::optimizations.num_groups_to_merge } -> std::convertible_to; + { T::optimizations.split_image } -> std::convertible_to; + { T::optimizations.explicit_gemm } -> std::convertible_to; +}; + +template +concept SpecifiesTileConvSpecialization = requires { + { T::specialization } -> std::convertible_to; +}; + +template +concept SpecifiesFwdConvSpecialization = requires { { T::fwd_specialization } -> std::convertible_to; }; diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 093916dac3..10a619024a 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires { Value.lds_dst_scalar_per_vector > 0; }; +// Limits for input and output vector transfer (CK Tile). 
+template +concept TileInputOutputVectorTransferLimits = + requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; }; + // Limits for output vector transfer. template concept OutputVectorTransferLimits = requires { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 51945544b2..9a9c2235e0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -59,6 +59,7 @@ #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" +#include "ck_tile/builder/factory/conv_tile_factory.hpp" namespace ck_tile::builder::factory { @@ -81,6 +82,15 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. 
+// CK Tile kernel +template +consteval bool IsTileAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && + SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && + SpecifiesTileOptimizations; +} + // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template consteval bool IsXdlV3Algorithm() @@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesBlockGemm; } @@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; } @@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm() return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; } @@ -120,7 +130,7 @@ template consteval bool IsDlAlgorithm() { return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && 
SpecifiesDlThreadCluster && SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; } @@ -137,10 +147,15 @@ template constexpr auto make_conv_instance() { - if constexpr(ConvDirectionIsForward) - { - using AlgoType = std::remove_const_t; + using AlgoType = std::remove_const_t; + // CK Tile supports common factory for each direction + if constexpr(IsTileAlgorithm()) + { + return typename ConvTileFactory::Instance{}; + } + else if constexpr(ConvDirectionIsForward) + { if constexpr(IsXdlV3Algorithm()) { return typename ConvFwdXdlV3Factory::Instance{}; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index 0c675ac7f1..ca202aabfd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -7,11 +7,11 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp 
index 98e368ca61..fadf41f48a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 79955a1f44..89787cc1b3 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" 
-#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index fcce46aea7..bb84479071 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index df7fb25168..8ec5c633ce 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -8,12 +8,12 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/conv_algorithm_limits.hpp" #include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" -#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace ck_tile::builder::factory { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp new file mode 100644 index 0000000000..cce95cb3f1 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp" + +namespace ck_tile::builder::factory { + +// Factory for CK Tile Grouped Convolution kernels. +template +struct ConvTileFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::TileConvTensorLayouts; + using Types = internal::TileConvTensorTypes; + using Ops = internal::TileElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization(); + static constexpr auto BLOCK = internal::SetTileThreadBlockInfo(); + static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm(); + static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations(); + static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer(); + static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. 
+ static_assert(TileInputOutputVectorTransferLimits); + + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + BLOCK_GEMM.double_smem_buffer, + typename GroupedConvTraitsType::template GemmLayouts::AsLayout, + typename GroupedConvTraitsType::template GemmLayouts::BsLayout, + typename GroupedConvTraitsType::template GemmLayouts::CLayout, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + BLOCK_GEMM.num_wave_groups>; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + GemmShape, + GemmUniversalTraits, + BLOCK_GEMM.scheduler, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Types::EDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = typename internal::TilePipelineType< + BLOCK_GEMM.pipeline_version>::template GemmPipeline; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Instance = typename internal::GroupedConvolutionTileKernel::Instance; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp 
b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp similarity index 100% rename from experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp rename to experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp new file mode 100644 index 0000000000..fbeb48b045 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +struct TileScalarPerVector +{ + size_t a = 0; + size_t b = 0; + size_t c = 0; +}; + +template +constexpr TileScalarPerVector SetTileBlockTransfer() +{ + return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector, + .b = ALGORITHM.transfer.b_scalar_per_vector, + .c = ALGORITHM.transfer.c_scalar_per_vector}; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp new file mode 100644 index 0000000000..45ff7d265d --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp @@ -0,0 +1,62 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOpToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported elementwise operation conversion to CK."); +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Scale; +}; + +template <> +struct ElementwiseOpToCKTile +{ + using Op = ck_tile::element_wise::Clamp; +}; + +template +consteval auto GetTileElementwiseOp() +{ + if constexpr(HasTensorOp) + { + constexpr auto op = TensorDesc.operation.elementwise_operation; + return ElementwiseOpToCKTile{}; + } + else + { + return ElementwiseOpToCKTile{}; + } +} + +template +struct TileElementwiseOps +{ + static constexpr auto input_op = GetTileElementwiseOp(); + static constexpr auto weight_op = GetTileElementwiseOp(); + static constexpr auto output_op = GetTileElementwiseOp(); + using AElementwiseOp = typename decltype(input_op)::Op; + using BElementwiseOp = typename decltype(weight_op)::Op; + using CDEElementwiseOp = typename decltype(output_op)::Op; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp new file mode 100644 index 0000000000..189b199ffc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct GroupedConvolutionTileKernel +{ + static_assert(false, "Unknown Direction"); +}; + +template + requires ConvDirectionIsForward +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionForwardKernel; +}; + +template + requires ConvDirectionIsBackwardData +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardDataKernel; +}; + +template + requires ConvDirectionIsBackwardWeight +struct GroupedConvolutionTileKernel +{ + using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel; +}; + +template +consteval ck_tile::GroupedConvDirection SetTileConvDirection() +{ + constexpr auto direction = SIGNATURE.direction; + using ck_tile_direction = ck_tile::GroupedConvDirection; + switch(direction) + { + case ConvDirection::FORWARD: return ck_tile_direction::FORWARD; + case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA; + case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT; + default: throw "Unknown Direction"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp new file mode 100644 index 0000000000..2aaca98586 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp @@ -0,0 +1,200 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { +using ALayout = ck_tile::tensor_layout::convolution::NWGC; +template +struct LayoutToCKTile +{ + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Unsupported layout conversion to CK."); +}; + +// Bias layouts +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_K; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::G_C; +}; + +// Input 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWC; +}; + +// Input 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWC; +}; + +// Input 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWC; +}; + +// Weight 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCX; +}; + +// Weight 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKYXC; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCYX; +}; + +// Weight 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKCZYX; +}; +template <> +struct 
LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GKZYXC; +}; + +// Output 1D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNWK; +}; + +// Output 2D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNHWK; +}; + +// Output 3D +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::NDHWGK; +}; +template <> +struct LayoutToCKTile +{ + using type = ck_tile::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto TensorLayoutToCKTile() +{ + return typename LayoutToCKTile::type{}; +} + +struct EmptyAuxiliaryTileTensorLayout +{ + using type = ck_tile::tuple<>; +}; + +template +consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence) +{ + return ck_tile::tuple< + decltype(TensorLayoutToCKTile())...>{}; +} + +template + requires(ConvSpatialDim) +struct AuxiliaryTileTensorLayouts +{ + static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size(); + using type = decltype(GetAuxiliaryTileTensorLayoutTuple( + std::make_index_sequence{})); +}; + +// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias). 
+template + requires(HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return AuxiliaryTileTensorLayouts{}; +} + +template + requires(!HasElementwiseOpWithAuxiliaryOperands) +consteval auto GetAuxiliaryTileTensorLayouts() +{ + return EmptyAuxiliaryTileTensorLayout{}; +} + +template + requires(ConvSpatialDim && + ValidConvInputLayoutForSpatialDim && + ValidConvWeightLayoutForSpatialDim && + ValidConvOutputLayoutForSpatialDim) +struct TileConvTensorLayouts +{ + using ALayout = decltype(TensorLayoutToCKTile()); + using BLayout = decltype(TensorLayoutToCKTile()); + using ELayout = decltype(TensorLayoutToCKTile()); + using DsLayout = decltype(GetAuxiliaryTileTensorLayouts())::type; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp new file mode 100644 index 0000000000..493fbb7d9b --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK Tile tensor types. +template +struct TileConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. 
Unsupported data type for convolution factory."); +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::half_t; + using AComputeType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using BComputeType = ck_tile::half_t; + using CShuffleDataType = ck_tile::half_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::half_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::bf16_t; + using AComputeType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using BComputeType = ck_tile::bf16_t; + using CShuffleDataType = ck_tile::bf16_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::bf16_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct TileConvTensorTypes +{ + using ADataType = ck_tile::fp8_t; + using AComputeType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using BComputeType = ck_tile::fp8_t; + using CShuffleDataType = ck_tile::fp8_t; + using DsDataTypes = ck_tile::tuple<>; + using AccDataType = float; + using EDataType = ck_tile::fp8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp new file mode 100644 index 
0000000000..65d81a49c4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. +struct TileBlockMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileConvBlock +{ + TileBlockMNK per_block = {}; +}; + +template +constexpr TileConvBlock SetTileThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return TileConvBlock{ + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp new file mode 100644 index 0000000000..b7df0e4d0e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// Convenience struct for a tuple of m, n, and k values. 
+struct TileBlockGemmMNK +{ + int m{}; + int n{}; + int k{}; +}; + +struct TileBlockGemmSpec +{ + TileBlockGemmMNK warps = {}; + TileBlockGemmMNK warp_tile = {}; + + bool double_smem_buffer = false; + int num_wave_groups = 1; + + ck_tile::GemmPipeline pipeline_version; + ck_tile::GemmPipelineScheduler scheduler; +}; + +struct TileOptimizations +{ + int num_groups_to_merge = 1; + bool split_image = false; + bool explicit_gemm = false; +}; + +template +consteval ck_tile::GemmPipelineScheduler SetTileScheduler() +{ + constexpr auto scheduler = ALGORITHM.block_gemm.scheduler; + using ck_tile_sched = ck_tile::GemmPipelineScheduler; + switch(scheduler) + { + case PipelineScheduler::DEFAULT: return ck_tile_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave; + case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave; + default: throw "Unknown PipelineScheduler"; + } +} + +template +struct TilePipelineType +{ + static_assert(false, "Unknown PipelineScheduler"); +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; +}; + +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; +}; + +template +consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.block_gemm.pipeline_version; + using ck_tile_pipeline = ck_tile::GemmPipeline; + switch(version) + { + case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1; + case PipelineVersion::V2: return ck_tile_pipeline::MEMORY; + case PipelineVersion::V3: return 
ck_tile_pipeline::COMPUTE_V3; + case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; + case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.specialization; + using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization; + switch(specialization) + { + case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default; + case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0; + case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_tile_conv_spec::Filter1x1Stride1Pad0; + case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; + } +} + +template +consteval TileBlockGemmSpec SetTileBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + constexpr bool double_smem_buffer = BG.double_smem_buffer; + constexpr int num_wave_groups = BG.num_wave_groups; + + constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion(); + constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler(); + + return TileBlockGemmSpec{ + .warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k}, + .warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k}, + .double_smem_buffer = double_smem_buffer, + .num_wave_groups = num_wave_groups, + .pipeline_version = pipeline_version, + .scheduler = scheduler}; +} + +template +consteval TileOptimizations SetTileOptimizations() +{ + constexpr auto& OPT = ALGORITHM.optimizations; + + return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge, + .split_image = OPT.split_image, + .explicit_gemm = OPT.explicit_gemm}; +} + +} // 
namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 261c3f103d..59ff83c238 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -251,14 +251,10 @@ class ConvDescription : public Description }; } // namespace conv -/// @brief Helper concept to detect if a type has ConvTraits specialization -template -concept HasConvTraits = requires { typename conv::ConvTraits; }; - /// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have InstanceTraits specialization) +/// @tparam Instance The convolution instance type (must have ConvTraits specialization) /// @return A ConvDescription object populated with the instance's configuration details -template +template conv::ConvDescription describe() { using Traits = conv::ConvTraits; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 29ac49e549..e5a5638887 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -4,23 +4,74 @@ #pragma once #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" 
+#include "ck_tile/builder/conv_builder.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" #include "ck_tile/ops/epilogue.hpp" -#include +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// GEMM specialization concept - checks for kGemmSpecialization member +template +concept HasGemmSpec = requires { + { + T::kGemmSpecialization + } -> std::convertible_to; +}; + +// Data types concept - checks for ADataType member +template +concept HasDataTypes = requires { typename T::ADataType; }; + +// Elementwise operations concept - checks for A/B/CDE elementwise operation types +template +concept HasElementwiseOps = requires { + typename T::AElementwiseOperation; + typename T::BElementwiseOperation; + typename T::CDEElementwiseOperation; +}; + +// Tile parameters concept - checks for tile dimension and transfer members +template +concept HasTileParams = requires { + { T::kKPerBlock } -> std::convertible_to; + { T::kMPerBlock } -> std::convertible_to; + { T::kNPerBlock } -> std::convertible_to; + { T::kAK1 } -> std::convertible_to; + { T::kBK1 } -> std::convertible_to; + T::kCThreadClusterLengths; +}; + +// Comprehensive concept that checks if an instance has all XDL forward convolution traits +// This concept is used to constrain ConvTraits specializations that expect XDL forward convolutions +template +concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && + HasElementwiseOps && HasTileParams; + +// Primary concept for checking if a type can be described +// Currently only forward convolutions 
are supported, but this can be extended +// in the future to include backward data and backward weight convolutions +template +concept HasConvTraits = IsXdlFwdConv>; + // Helper metafunctions to convert from ck enums to builder enums /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. @@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version() { using enum ck::BlockGemmPipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v3) - return V3; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == v5) - return V5; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } } /// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. @@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version() { using enum ck::PipelineVersion; using enum builder::PipelineVersion; - if constexpr(ck_ver == v1) - return V1; - else if constexpr(ck_ver == v2) - return V2; - else if constexpr(ck_ver == v4) - return V4; - else if constexpr(ck_ver == weight_only) - return WEIGHT_ONLY; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } } /// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. @@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::BlockGemmPipelineScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Intrawave) - return INTRAWAVE; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } } /// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. 
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler() { using enum ck::LoopScheduler; using enum builder::PipelineScheduler; - if constexpr(ck_sched == Default) - return DEFAULT; - else if constexpr(ck_sched == Interwave) - return INTERWAVE; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } } /// @brief Helper structures for organizing trait data with domain-specific naming @@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction() using InstTraits = InstanceTraits; if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - { return builder::ConvDirection::FORWARD; - } else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - { return builder::ConvDirection::BACKWARD_DATA; - } else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - { return builder::ConvDirection::BACKWARD_WEIGHT; - } else - { return builder::ConvDirection::FORWARD; // Default fallback - } } /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. 
@@ -242,60 +288,52 @@ constexpr auto conv_spec() if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvFwdSpecialization; - if constexpr(InstTraits::kConvForwardSpecialization == Default) + switch(InstTraits::kConvForwardSpecialization) { - return builder::ConvFwdSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) - { - return builder::ConvFwdSpecialization::FILTER_3x3; + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; } } else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvBwdDataSpecialization; - if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + switch(InstTraits::kConvBwdDataSpecialization) { - return builder::ConvBwdDataSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; } } else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvBwdWeightSpecialization; - if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + switch(InstTraits::kConvBwdWeightSpecialization) { - return 
builder::ConvBwdWeightSpecialization::DEFAULT; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) - { - return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; - } - else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) - { - return builder::ConvBwdWeightSpecialization::ODD_C; + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; } } } +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return An std::array corresponding to the tensor layouts: @@ -304,112 +342,49 @@ constexpr auto conv_spec() /// index 2 -> Output layout template constexpr auto conv_layout() + requires HasFwdConvLayouts> { - using InstTraits = InstanceTraits; - using ALayout = typename InstTraits::ALayout; - using BLayout = typename InstTraits::BLayout; - using ELayout = typename InstTraits::ELayout; + // Helper lambda to construct layout array + auto layouts = [](auto... 
Ls) { return std::array{Ls...}; }; - namespace ctc = ck::tensor_layout::convolution; + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(InstTraits::kSpatialDim == 1) + switch(InstanceTraits::kSpatialDim) { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNWC, - builder::TensorLayout::GKXC, - builder::TensorLayout::GNWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NWGC, - builder::TensorLayout::GKXC, - builder::TensorLayout::NWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKXC, - builder::TensorLayout::NGKW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && std::is_same_v) - { - return std::array{builder::TensorLayout::NGCW, - builder::TensorLayout::GKCX, - builder::TensorLayout::NGKW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 2) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNHWC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::GNHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NHWGC, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKYXC, - builder::TensorLayout::NGKHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCHW, - builder::TensorLayout::GKCYX, - 
builder::TensorLayout::NGKHW}; - } - } - else if constexpr(InstTraits::kSpatialDim == 3) - { - if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::GNDHWC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::GNDHWK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NDHWGC, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NDHWGK}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKZYXC, - builder::TensorLayout::NGKDHW}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return std::array{builder::TensorLayout::NGCDHW, - builder::TensorLayout::GKCZYX, - builder::TensorLayout::NGKDHW}; - } + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } } @@ -418,39 +393,26 @@ constexpr auto conv_layout() /// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). 
template constexpr builder::DataType conv_data_type() + requires HasDataTypes> { using InstTraits = InstanceTraits; using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; if constexpr(std::is_same_v) - { - return builder::DataType::FP16; - } + return FP16; else if constexpr(std::is_same_v) - { - return builder::DataType::BF16; - } + return BF16; else if constexpr(std::is_same_v) - { - return builder::DataType::FP32; - } + return FP32; else if constexpr(std::is_same_v) - { - return builder::DataType::FP8; - } + return FP8; else if constexpr(std::is_same_v) - { - return builder::DataType::I8; - } + return I8; else if constexpr(std::is_same_v) - { - return builder::DataType::U8; - } + return U8; else - { - // Default fallback - return builder::DataType::FP32; - } + return FP32; // Default fallback } /// @brief Derives the elementwise operation from op type. @@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type() template constexpr builder::ElementwiseOperation elementwise_op() { + using enum builder::ElementwiseOperation; constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - { - return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - { - return builder::ElementwiseOperation::CLAMP; - } - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - { - return builder::ElementwiseOperation::SCALE; - } - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - { - return builder::ElementwiseOperation::PASS_THROUGH; - } - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - { - return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU; - } + return BIAS_BNORM_CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + if 
constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; } /// @brief Derives a gemm padding from a kernel instance type. @@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op() /// @return A `builder::GemmPadding` enum value corresponding to kernel padding. template constexpr builder::GemmPadding gemm_spec() + requires HasGemmSpec> { using InstTraits = InstanceTraits; using enum builder::GemmPadding; @@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec() constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - if constexpr(gemm_spec == Default) + switch(gemm_spec) { - return DEFAULT; - } - else if constexpr(gemm_spec == MPadding) - { - return M_PADDING; - } - else if constexpr(gemm_spec == NPadding) - { - return N_PADDING; - } - else if constexpr(gemm_spec == KPadding) - { - return K_PADDING; - } - else if constexpr(gemm_spec == MNPadding) - { - return MN_PADDING; - } - else if constexpr(gemm_spec == MKPadding) - { - return MK_PADDING; - } - else if constexpr(gemm_spec == NKPadding) - { - return NK_PADDING; - } - else if constexpr(gemm_spec == MNKPadding) - { - return MNK_PADDING; - } - else if constexpr(gemm_spec == OPadding) - { - return O_PADDING; - } - else if constexpr(gemm_spec == MOPadding) - { - return MO_PADDING; - } - else if constexpr(gemm_spec == NOPadding) - { - return NO_PADDING; - } - else if constexpr(gemm_spec == KOPadding) - { - return KO_PADDING; - } - else if constexpr(gemm_spec == MNOPadding) - { - return MNO_PADDING; - } - else if constexpr(gemm_spec == MKOPadding) - { - return MKO_PADDING; - } - else if constexpr(gemm_spec == NKOPadding) - { - return NKO_PADDING; - } - else if constexpr(gemm_spec == MNKOPadding) - { - return MNKO_PADDING; + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return 
K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; } } @@ -571,6 +481,7 @@ struct ConvTraits; /// set of traits directly from a fully-formed device kernel `Instance` type. /// It uses `InstanceTraits` to access the kernel's template parameters. template + requires IsXdlFwdConv> struct ConvTraits { using InstTraits = InstanceTraits; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 64996f96f7..1055cbc038 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -8,28 +8,30 @@ #pragma once #include -#include -#include -#include -#include -#include -#include #include -#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ck_tile/ops/epilogue.hpp" +#include +#include +#include +#include +#include +#include +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include 
"ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 565bb98528..532d8a1882 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -145,6 +145,15 @@ enum struct GemmSpecialization MNKOPadding }; +// Enums for the CK Tile convolution specialization. +enum class TileConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3x3 +}; + // Enums for the forward convolution specialization. enum class ConvFwdSpecialization { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index a340a789de..eef1110d27 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) + conv/ck/test_conv_traits.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description @@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string # Tests the forward convolution builder across multiple data types and dimensions. # Individual tests are split into separate files to enable parallel compilation. 
add_ck_builder_test(test_ckb_build_fwd_instances - conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp - conv/test_ckb_conv_fwd_1d_fp16.cpp - conv/test_ckb_conv_fwd_1d_bf16.cpp - conv/test_ckb_conv_fwd_1d_i8.cpp - conv/test_ckb_conv_fwd_2d_fp8.cpp - conv/test_ckb_conv_fwd_2d_bf16.cpp - conv/test_ckb_conv_fwd_2d_fp16.cpp - conv/test_ckb_conv_fwd_2d_fp32.cpp - conv/test_ckb_conv_fwd_2d_dl_fp16.cpp - conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp - conv/test_ckb_conv_fwd_3d_bf16.cpp - conv/test_ckb_conv_fwd_3d_fp16.cpp - conv/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp + conv/ck/test_ckb_conv_fwd_1d_fp16.cpp + conv/ck/test_ckb_conv_fwd_1d_bf16.cpp + conv/ck/test_ckb_conv_fwd_1d_i8.cpp + conv/ck/test_ckb_conv_fwd_2d_fp8.cpp + conv/ck/test_ckb_conv_fwd_2d_bf16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_fp32.cpp + conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp + conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_bf16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp16.cpp + conv/ck/test_ckb_conv_fwd_3d_fp32.cpp + conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp ) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp 
b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_1d_i8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp similarity index 
100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp8.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp similarity index 100% rename from experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp rename to experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp similarity index 100% rename from experimental/builder/test/conv/test_conv_traits.cpp rename to experimental/builder/test/conv/ck/test_conv_traits.cpp diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp new file mode 100644 index 0000000000..ad31fc52bc --- /dev/null +++ 
b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_data", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp new file mode 100644 index 0000000000..47908e0e5b --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_backward_weight", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp new file mode 100644 index 0000000000..083d9d9955 --- /dev/null +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_tile_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" + +namespace { + +using namespace ck_tile::builder::test_utils; + +TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +{ + constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, + .direction = ConvDirection::FORWARD, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; + + constexpr auto FwdConvAlgorithm = + ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(TileConvSpecialization::DEFAULT) + .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_optimizations(TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + + using Builder = ConvBuilder; + run_ck_tile_test({ + "grouped_convolution_forward", + "fp16", + "NHWGC_GKYXC_NHWGK", + "64x64x64", + "2x2", + "16x16x16", + // "4x4x4", // TODO: Enable this check + "Default", + "Intrawave", + "CShuffleEpilogue", + "set", + "pipeline_AgBgCrCompV3", + "DoubleSmemBuffer_0", + "NumWaveGroups_1", + "MergedGroups_1", + "SplitImage_0", + "ExplicitGemm_0", + }); +} + +} // namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index d89d83357f..29c7f3cdcc 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -243,6 +243,73 @@ struct LargeTensorWrapper ConvAlgorithmSpecialization::LARGE_TENSOR; }; +// Specify thread block dimensions for a GEMM (CK Tile). +struct TileThreadBlock +{ + // Size of the submatrix problem in a thread block. 
+ MNK tile_size; +}; +static_assert(ckb::TileThreadBlockDescriptor); + +struct TileTransfer +{ + size_t a_scalar_per_vector; + size_t b_scalar_per_vector; + size_t c_scalar_per_vector; +}; +static_assert(ckb::TileTransferDescriptor); + +struct TileBlockGemm +{ + // Number of warps per each dimension. + MNK warps; + // Number of data processed per each dimension for each XDL/WMMA instruction. + MNK warp_tile; + // Double LDS buffer. + bool double_smem_buffer; + // Waves grouping (Ping-Pong scheduler). + int num_wave_groups; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; +}; +static_assert(ckb::TileBlockGemmDescriptor); + +struct TileOptimizations +{ + // Number of convolution groups processed per one workgroup + int num_groups_to_merge; + // Split image for large tensors + bool split_image; + // Explicit gemm for 1x1, stride=1, pad=0 cases + bool explicit_gemm; +}; +static_assert(ckb::TileOptimizationsDescriptor); + +struct TileConvSpecialization_ +{ + TileConvSpecialization specialization; +}; + +struct TileThreadBlock_ +{ + TileThreadBlock thread_block; +}; + +struct TileTransfer_ +{ + TileTransfer transfer; +}; + +struct TileBlockGemm_ +{ + TileBlockGemm block_gemm; +}; + +struct TileOptimizations_ +{ + TileOptimizations optimizations; +}; + // Factory template @@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components... 
result.transfer = t; return result; } + + template + constexpr auto with_tile_specializations(const S& s) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.specialization = s; + return result; + } + + template + constexpr auto with_tile_thread_block(const TB& tb) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.thread_block = tb; + return result; + } + + template + constexpr auto with_tile_block_gemm(const BG& bg) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.block_gemm = bg; + return result; + } + + template + constexpr auto with_tile_transfer(const T& t) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.transfer = t; + return result; + } + + template + constexpr auto with_tile_optimizations(const O& o) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.optimizations = o; + return result; + } }; // Algorithm types @@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = LargeTensorWrapper; +using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index 80e8ae8d98..f26b5d7caf 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -2,9 +2,10 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp 
b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index 9b3cd169bb..c7c4e370e2 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -2,10 +2,12 @@ // SPDX-License-Identifier: MIT #include -#include -#include -#include -#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6a8f1f14e3..396533cef4 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -1,17 +1,19 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" namespace { diff --git a/experimental/builder/test/test_instance_traits_util.cpp b/experimental/builder/test/test_instance_traits_util.cpp index 42810ace72..852174b805 100644 --- a/experimental/builder/test/test_instance_traits_util.cpp +++ b/experimental/builder/test/test_instance_traits_util.cpp @@ -1,16 +1,16 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/sequence.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" namespace ck_tile::reflect::detail { namespace { diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp index 84a9c533f6..610edd281e 100644 --- a/experimental/builder/test/unit_conv_elementwise_op.cpp +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index 7764e94dc6..26df33cc8d 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "impl/conv_signature_types.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index c92b24626e..7ffd446966 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -4,7 +4,7 @@ #include #include -#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include 
"ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp index f829708696..ce5a772cfa 100644 --- a/experimental/builder/test/unit_conv_thread_block.cpp +++ b/experimental/builder/test/unit_conv_thread_block.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include -#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" namespace { diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index 82117c53d8..b35a1ced55 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -3,7 +3,7 @@ #include -#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" namespace { diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index 508c621c2e..1acf170455 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -28,4 +28,20 @@ constexpr void run_test(const std::vector& kernel_instance_componen } } +// Common CK Tile test implementation +template +constexpr void run_ck_tile_test(const std::vector& kernel_instance_components) +{ + auto instance = typename Builder::Instance{}; + + const auto kernel_string = instance.GetTypeString(); + std::cout << "Generated kernel: " << kernel_string << std::endl; + EXPECT_GT(kernel_string.size(), 0); + std::cout << kernel_string << std::endl; + for(const auto& component : kernel_instance_components) + { + EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); + } +} + } // namespace ck_tile::builder::test_utils diff --git 
a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp new file mode 100644 index 0000000000..377234dd19 --- /dev/null +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "impl/conv_algorithm_types.hpp" +#include "impl/conv_signature_types.hpp" +#include "ck_tile/builder/conv_builder.hpp" + +namespace ck_tile::builder::test_utils { + +using namespace ck_tile::builder; +using namespace test; + +constexpr TileTransfer FwdTileTransfer_1x1x1{ + .a_scalar_per_vector = 1, + .b_scalar_per_vector = 1, + .c_scalar_per_vector = 1, +}; + +constexpr TileTransfer FwdTileTransfer_4x4x4{ + .a_scalar_per_vector = 4, + .b_scalar_per_vector = 4, + .c_scalar_per_vector = 4, +}; + +constexpr TileTransfer FwdTileTransfer_8x8x8{ + .a_scalar_per_vector = 8, + .b_scalar_per_vector = 8, + .c_scalar_per_vector = 8, +}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; + +constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V1, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm 
TileBlockGemmDesc_16x16_v2_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V2, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V3, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V4, + .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = { + .warps = {.m = 2, .n = 2, .k = 1}, + .warp_tile = {.m = 16, .n = 16, .k = 16}, + .double_smem_buffer = false, + .num_wave_groups = 1, + .pipeline_version = PipelineVersion::V5, + .scheduler = PipelineScheduler::INTRAWAVE}; + +} // namespace ck_tile::builder::test_utils diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 306a6c2ff1..113bf99243 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,6 +55,11 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#define CK_ENABLE_TF32 "ON" +#endif +#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -85,6 +90,12 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif +#ifndef CK_ENABLE_TF32 +#if defined(__gfx942__) || defined(__gfx95__) +#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ +#endif +#endif + #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 8d45b8fd74..751608299c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -5,24 +5,16 @@ #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include -#include #endif +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" namespace ck { -enum struct PipelineVersion -{ - v1, - v2, - // v3 is only used in the Stream-K implementation. - v4, - weight_only, -}; - template Prefetch stages, number of loop is multiple of unroll stages - Empty, - // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add - // prefetchstages - Full, -}; - enum SchedulerGroup : uint32_t { SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index f186d0fea9..b3303e1138 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -3,40 +3,20 @@ #pragma once -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include -#endif - #include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - +/// @brief Helper function to get default loop scheduler +/// @details Returns the default loop scheduler based on compile-time configuration. 
constexpr LoopScheduler make_default_loop_scheduler() { #if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING return LoopScheduler::Interwave; #else return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +#endif } } // namespace ck - -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) -{ - switch(s) - { - case ck::LoopScheduler::Default: os << "Default"; break; - case ck::LoopScheduler::Interwave: os << "Interwave"; break; - default: os << ""; - } - return os; -} -#endif diff --git a/include/ck/utility/pipeline_enum.hpp b/include/ck/utility/pipeline_enum.hpp new file mode 100644 index 0000000000..4421386f59 --- /dev/null +++ b/include/ck/utility/pipeline_enum.hpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Pipeline version enumeration for GEMM kernels +/// @details Defines different pipeline strategies for data movement and computation overlap +/// in GEMM kernels. This is a lightweight header containing only the enum definition, +/// extracted from gridwise_gemm_pipeline_selector.hpp to minimize dependencies. +enum struct PipelineVersion +{ + v1, ///< Version 1 pipeline + v2, ///< Version 2 pipeline + // v3 is only used in the Stream-K implementation. 
+ v4, ///< Version 4 pipeline + weight_only, ///< Weight-only specialized pipeline +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} +#endif diff --git a/include/ck/utility/scheduler_enum.hpp b/include/ck/utility/scheduler_enum.hpp new file mode 100644 index 0000000000..0c4bfabaf3 --- /dev/null +++ b/include/ck/utility/scheduler_enum.hpp @@ -0,0 +1,83 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#endif + +namespace ck { + +/// @brief Block GEMM pipeline version enumeration +/// @details Defines different block GEMM pipeline strategies. +/// This is a lightweight header containing only enum definitions, +/// extracted from blkgemmpipe_scheduler.hpp to minimize dependencies. +enum struct BlockGemmPipelineVersion +{ + // For GEMM + v1, ///< Naive pipeline + v2, ///< Memory-optimized pipeline + v3, ///< Compute-optimized pipeline + v4, ///< Compute-optimized with double LDS buffer + v5, ///< Compute-optimized with double global prefetch register buffer + + // For GEMM with preshuffled weight + // v1, single lds buffer + // v2, double lds buffer +}; + +/// @brief Block GEMM pipeline scheduler enumeration +/// @details Defines scheduling strategies for block GEMM pipelines. 
+enum struct BlockGemmPipelineScheduler +{ + Intrawave, ///< Schedule within a single wavefront + Interwave, ///< Schedule across multiple wavefronts +}; + +/// @brief Loop scheduler enumeration +/// @details Defines scheduling strategies for computational loops. +enum struct LoopScheduler +{ + Default, ///< Default scheduling strategy + Interwave, ///< Cross-wavefront scheduling +}; + +/// @brief Tail number enumeration for pipeline buffering +/// @details Defines the number of tail iterations in pipelined loops. +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, ///< Odd number of iterations + Even, ///< Even number of iterations + + // Long prefetch pipeline, up to 8 + One, ///< One tail iteration + Two, ///< Two tail iterations + Three, ///< Three tail iterations + Four, ///< Four tail iterations + Five, ///< Five tail iterations + Six, ///< Six tail iterations + Seven, ///< Seven tail iterations + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, ///< No tail iterations + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, ///< Full tail iterations +}; + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) +{ + switch(s) + { + case ck::LoopScheduler::Default: os << "Default"; break; + case ck::LoopScheduler::Interwave: os << "Interwave"; break; + default: os << ""; + } + return os; +} +#endif diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 81eea60c2f..29a7e2593e 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing& printf("}"); } +template +struct functor_transform : public base_transform<1, 1> +{ + 
using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + Functor functor_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr functor_transform() = default; + + CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor, + const LowLength& low_length) + : functor_{functor}, up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = functor_(idx_up[number<0>{}]); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + calculate_lower_index(idx_low, up_idx); + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // Note: When using functor_transform, ensure that the transformed coordinates + // are always valid for vectorized load/store operations. 
+ template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } +}; + //******************************************************************************************************* template @@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le return offset{low_length, offset_length}; } +template +CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor, + const LowLength& low_length) +{ + return functor_transform{functor, low_length}; +} + } // namespace ck_tile #include "ck_tile/core/algorithm/indexing_adaptor.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index de97b46336..678a2fbfff 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state #endif // __gfx950__ // GFX10 +#if defined(__gfx1010__) + static constexpr bool CK_TILE_ARCH_GFX1010 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1010 = false; +#endif + #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; #else @@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... 
se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 97a44f38e8..7a4da64c4a 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -533,7 +533,8 @@ struct tile_scatter_gather size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using Traits = load_store_traits; diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 770bca4a4c..6bfbbb09dd 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths } }; -template +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths +template >> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 815c1bf158..6c84122d01 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -517,7 +517,8 @@ struct tile_window_linear 
size_per_buf; const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(m0_init_value); // This should be wave independent + m0_set_with_memory( + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using vector_t = typename Base::Traits::vector_t; diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 12f43ebc5e..4bbf8cbf3f 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -33,59 +33,73 @@ namespace ck_tile { * @example * * // Direct usage without creating a separate variable: - * ck_tile::FillUniformDistribution{-1.f, 1.f}(a_host_tensor); + * ck_tile::FillUniformDistribution<>{-1.f, 1.f}(a_host_tensor); */ -template +template struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; std::optional seed_{11939}; - // ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed - // across threads). - bool threaded = false; template void operator()(ForwardIter first, ForwardIter last) const { - if(threaded) - { - uint32_t num_thread = std::thread::hardware_concurrency(); - auto total = static_cast(std::distance(first, last)); - auto work_per_thread = static_cast((total + num_thread - 1) / num_thread); + if(first == last) + return; + using T_iter = std::decay_t; + static_assert(std::is_same_v || std::is_void_v, + "Iterator value type must match template type T"); + constexpr auto PackedSize = numeric_traits::PackedSize; + const auto total = static_cast(std::distance(first, last)); + const auto total_bytes = total * sizeof(T_iter); - std::vector threads(num_thread); - for(std::size_t it = 0; it < num_thread; ++it) - { - std::size_t iw_begin = it * work_per_thread; - std::size_t iw_end = std::min((it + 1) * work_per_thread, total); - auto thread_f = [this, total, iw_begin, iw_end, &first] { - if(iw_begin > total || iw_end > total) - return; - // need to make each thread unique, add an offset to current seed - 
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); - }; - threads[it] = joinable_thread(thread_f); - } - } - else + // max 80 threads; at least 2MB per thread + const size_t available_cpu_cores = get_available_cpu_cores(); + const size_t num_thread = + min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); + constexpr size_t BLOCK_BYTES = 64; + constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter); + const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES); + const size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread); + + // use minstd_rand for better performance on discard() + std::minstd_rand gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + + std::vector threads; + threads.reserve(num_thread - 1); // last job run in the main thread + for(int it = num_thread - 1; it >= 0; --it) { - std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first, last, [&dis, &gen]() { - if constexpr(numeric_traits::PackedSize == 2) - return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); - else - return ck_tile::type_convert(dis(gen)); - }); + const size_t ib_begin = it * blocks_per_thread; + const size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks); + + auto job = [=]() { + auto g_ = gen; // copy + auto d_ = dis; // copy + g_.discard(ib_begin * BLOCK_SIZE * PackedSize); + auto t_fn = [&]() { + if constexpr(PackedSize == 2) + return type_convert(fp32x2_t{d_(g_), d_(g_)}); + else + return type_convert(d_(g_)); + }; + + size_t ib = ib_begin; + for(; ib < ib_end - 1; ++ib) // full blocks + static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) { + constexpr size_t iw = iw_.value; + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }); + for(size_t iw = 0; iw < BLOCK_SIZE; ++iw) // last block + if(ib * BLOCK_SIZE + iw < total) + *(first + ib * BLOCK_SIZE + iw) = t_fn(); + }; + + if(it > 0) + threads.emplace_back(std::move(job)); + else + job(); // last job run in the main thread } } diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index bf84858ee2..b2e1fc4dac 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -3,6 +3,9 @@ #pragma once +#ifdef __linux__ +#include +#endif #include #include @@ -24,4 +27,50 @@ struct joinable_thread : std::thread this->join(); } }; + +inline unsigned int get_available_cpu_cores() +{ +#if defined(__linux__) + cpu_set_t cpu_set; + if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0) + { + unsigned int cpu_count = CPU_COUNT(&cpu_set); + if(cpu_count > 0) + return cpu_count; + } +#endif + // Fallback if sched_getaffinity unavailable or fails + return std::thread::hardware_concurrency(); +} + +class cpu_core_guard +{ +#if defined(__linux__) + cpu_set_t original_cpu_set_; + + public: + 
cpu_core_guard(unsigned int num_cores) : original_cpu_set_() + { + // save original cpu set + sched_getaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + + // set new cpu set + cpu_set_t new_cpu_set; + CPU_ZERO(&new_cpu_set); + for(unsigned int i = 0; i < num_cores; ++i) + { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + CPU_SET(i, &new_cpu_set); // NOLINT(old-style-cast) +#pragma clang diagnostic pop + } + sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set); + } + ~cpu_core_guard() + { + // restore original cpu set + sched_setaffinity(0, sizeof(cpu_set_t), &original_cpu_set_); + } +#endif +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b3b34a6da0..7104547247 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel auto fused_token = kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0] - index_t scatter_token_id = fused_token & token_id_mask; + index_t scatter_token_id = fused_token & token_id_mask; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); if constexpr(IsInputGemm) scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; - c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index d9fb144176..1133da33ad 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel; - using TilePartitioner = remove_cvref_t; - using FlatmmPipeline = remove_cvref_t; + using TilePartitioner = remove_cvref_t; 
+ using MXFlatmmPipeline = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; using DsLayout = remove_cvref_t; using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; - static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; + static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; @@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel::PackedSize; static constexpr int BPackedSize = numeric_traits::PackedSize; - static constexpr int MXdlPack = FlatmmPipeline::MXdlPack; - static constexpr int NXdlPack = FlatmmPipeline::NXdlPack; - static constexpr int KXdlPack = FlatmmPipeline::KXdlPack; + static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack; + static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack; + static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack; static constexpr index_t NumDTensor = DsDataType::size(); @@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel, FlatmmPipeline::GetName()); + return concat('_', "mx_flatmm_gemm", gemm_prec_str, MXFlatmmPipeline::GetName()); // clang-format on } @@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - 
a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); }(); - constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock; + constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; const auto& b_flat_tensor_view = [&]() { - static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0, + static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( make_tuple(kFlatN, kFlatKBlocks, number{})); @@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); }(); const auto& b_flat_tensor_view = views.at(I1); @@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(d_tensor_view[i], make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }, number{}); @@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view(e_tensor_view, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); @@ -334,26 +316,18 @@ struct 
MXFlatmmKernel : FlatmmKernel) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } + static_assert(std::is_same_v, + "A tensor for mx must be RowMajor"); + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); }(); const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), + make_tuple(number{}, + number{}), {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); const auto ds_block_window = generate_tuple( @@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel(kargs.a_ptr) + - splitk_batch_offset.a_k_split_offset / APackedSize; - const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + - splitk_batch_offset.b_k_split_offset / BPackedSize; + const auto a_ptr = static_cast(kargs.a_ptr) + + splitk_batch_offset.a_k_split_offset / APackedSize; + const auto b_flat_ptr = static_cast(kargs.b_ptr) + + splitk_batch_offset.b_k_split_offset / BPackedSize; EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel::value)) { - constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, b_flat_ptr, kargs.ds_ptr, diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ff799cb0fc..87ae7f57d8 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem; using CLayout = remove_cvref_t; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = 
numeric_traits::PackedSize; + using BlockFlatmm = remove_cvref_t())>; @@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1::PackedSize; - static constexpr index_t BPackedSize = numeric_traits::PackedSize; + // static constexpr index_t WG_AKPacks = WG::kK / APackedSize; + // static constexpr index_t WG_BKPacks = WG::kK / BPackedSize; static constexpr index_t MXdlPack = Problem::MXdlPack; static constexpr index_t NXdlPack = Problem::NXdlPack; static constexpr index_t KXdlPack = Problem::KXdlPack; static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK; - static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; - static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; + static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType); static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? 
DsReadPreload @@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMX_BFlatDramTileDistribution()); + auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow( + b_flat_dram_block_window_tmp); auto b_flat_dram_offsets = generate_tuple( [&](auto nIter) { constexpr auto packed_n_idx = nIter / number{}; @@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, true_type{}, false_type{}); + async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{}); }; // HEAD @@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); // move B window to next flat K b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); // prefetch Scale A @@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + 
b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( - tuple, number>{}); + tuple, number>{}); }); }); @@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( b_flat_dram_window, - b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter); }); }); @@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) @@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - 
a_warp_tensor(number{}), - b_warp_tensor_pong(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_pong(number{})(number{})), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) @@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( c_warp_tensors(number{})(number{}), - a_warp_tensor(number{}), - b_warp_tensor_ping(number{})(number{}), + bit_cast( + a_warp_tensor(number{})), + bit_cast( + b_warp_tensor_ping(number{})(number{})), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 969cddf3e7..4d76ab7da2 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution() { - using TileShape = typename Problem::BlockGemmShape; + using TileShape = typename Problem::BlockGemmShape; + using BDataType = remove_cvref_t; + constexpr index_t BPack = numeric_traits::PackedSize; static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); @@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy tile_distribution_encoding< // sequence, tuple, // 4 2 - sequence>, // 1 64 32 + sequence>, // 1 64 32 tuple, sequence<2>>, tuple, sequence<1>>, sequence<2>, sequence<2>>, tile_distribution_encoding< // sequence, - tuple, // 
4 2 - sequence>, // 2 1 64 16 + tuple, // 4 2 + sequence>, // 2 1 64 16 tuple, sequence<2>>, tuple, sequence<2>>, sequence<2, 2>, sequence<0, 3>>>{}); } + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp) + { + + using BDataType = remove_cvref_t; + constexpr auto BPackedSize = numeric_traits::PackedSize; + constexpr auto kKPerBlock = Problem::BlockGemmShape::kK; + constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1); + constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp; + constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp; + + static_assert(std::decay_t::get_num_of_dimension() == 2); + auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view(); + const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths(); + constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile; + auto&& byte_tensor_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple( + flat_n, flat_k / flat_k_per_block, number{})), + make_tuple(make_pass_through_transform(flat_n), + make_merge_transform_v3_division_mod(make_tuple( + flat_k / flat_k_per_block, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0))); + auto&& byte_tensor_view = + make_tensor_view(byte_ptr, byte_tensor_desc); + auto&& origin_tmp = window_tmp.get_window_origin(); + return make_tile_window( + byte_tensor_view, + make_tuple(number{}, number{}), + {origin_tmp[0], origin_tmp[1] / BPackedSize}, + MakeMX_BFlatBytesDramTileDistribution()); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1a79aebef5..756968871d 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ 
b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -600,6 +600,19 @@ struct SimplifiedRatioAttentionMask mdiv y_ratio_mdiv; }; +template +struct is_generic_attention_mask : std::false_type +{ +}; + +template +struct is_generic_attention_mask> : std::true_type +{ +}; + +template +static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask::value; + // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 38830ee6fe..9890d1f2e4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -73,54 +73,6 @@ struct FmhaFwdKernel #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if 
(kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + - (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)) + (kUseTrLoad ? 
"_trload" : "_ntrload"); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index df17bdd879..f981c54bd8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -12,6 +12,8 @@ namespace ck_tile { +/// NOTICE: This kernel is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. template struct FmhaFwdV3Kernel { @@ -103,8 +105,8 @@ struct FmhaFwdV3Kernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -114,12 +116,13 @@ struct FmhaFwdV3Kernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; // Optional cumulative padded sequence starts (including PAD tokens) // Used solely to compute memory offsets when sequences are physically padded. 
- const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] - const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const int32_t* cu_seqlen_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -156,8 +159,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -199,8 +202,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -213,6 +216,7 @@ struct FmhaFwdV3Kernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -232,8 +236,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -258,6 +262,7 @@ struct FmhaFwdV3Kernel {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasMask) @@ -273,30 +278,29 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = 
reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - // TODO: this may need tuning - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1)); } else { - return dim3(nhead_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - batch_size_); + return dim3(nhead, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1), + batch_size); } } @@ -344,13 +348,20 @@ struct FmhaFwdV3Kernel // FmhaPipeline::kN1); // assume that num_tile_n1 is always 1 - if constexpr(kHasMask) + if constexpr(kIsGroupMode) { const index_t i_nhead = blockIdx.x; - const index_t i_block = blockIdx.y; - const index_t i_batch = blockIdx.z; + const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.z; - return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.z - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } else { @@ -358,7 +369,14 @@ struct FmhaFwdV3Kernel const index_t i_block = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + if 
constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); + } } } @@ -390,32 +408,36 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; if constexpr(kStoreLSE) { // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; + batch_offset_lse = query_start; } - batch_offset_o = query_start_padded * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + batch_offset_o = query_start * kargs.stride_o; + // real logical lengths (exclude PAD) + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch 
+ 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -427,10 +449,14 @@ struct FmhaFwdV3Kernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch]; } } else @@ -450,10 +476,10 @@ struct FmhaFwdV3Kernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 854e45c432..7cc424597a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -552,6 +552,15 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR }); }); } +#if defined(__gfx9__) + else + { + // Workaround for a compiler issue: sometimes there are not enough wait-states + // between v_mfma_f32... and v_accvgpr_read_b32 instructions if they are separated + // by s_cbranch. 
+ tile_elementwise_inout([](auto& x) { asm("; force move to %0" : "+v"(x)); }, s_acc); + } +#endif { bool need_perpixel_check = mask.IsEdgeTile( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 8bf24be386..68ec349694 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -246,6 +248,8 @@ CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) } } // namespace detail +/// NOTICE: This pipeline is a work in progress and is awaiting upcoming compiler fixes and +/// instruction scheduling optimizations. 
template struct BlockFmhaFwdV3Pipeline { @@ -261,12 +265,16 @@ struct BlockFmhaFwdV3Pipeline using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; + static_assert(is_generic_attention_mask_v); static_assert(std::is_same_v, "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); using BlockFmhaShape = ck_tile::remove_cvref_t; + using VLayout = remove_cvref_t; + static_assert(std::is_same_v); + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; @@ -277,14 +285,24 @@ struct BlockFmhaFwdV3Pipeline static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kQKHeaddim == 128 && kSubQKHeaddim == 128, "only supports hdim=hdim_v=128"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto QScaleEnum = Problem::QScaleEnum; + static 
constexpr bool kSkipMinSeqlenQ = Problem::kSkipMinSeqlenQ; + static_assert((!kHasLogitsSoftCap && BiasEnum == BlockAttentionBiasEnum::NO_BIAS && + !kStoreLSE && !kHasDropout && + (QScaleEnum == ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE) && + !kSkipMinSeqlenQ), + "enable unsupported features"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index da0fa16ee1..659bdd995b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -12,6 +12,7 @@ enum class BlockFmhaPipelineEnum QRKSVS_ASYNC, QSKSVS, QRKSVS_ASYNC_TRLOAD, + QRKSVS_ASYNC_TRLOAD_V3, }; template diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index b90b760a0d..7c4a921b70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -264,47 +264,4 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -template -struct BlockFmhaFwdV3PipelineProblem -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; - - static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; - static constexpr index_t 
kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; - static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); - - static constexpr bool kIsGroupMode = kIsGroupMode_; - - // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index b9e18de1e5..df33a93696 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -166,20 +166,4 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -template -struct TileFmhaFwdV3Traits -{ - static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadSeqLenK = kPadSeqLenK_; - static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; - static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; -}; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index d4475e8c60..8fae704203 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); return concat('_', "pipeline_AgBgCrCompV3", concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', 
GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', WaveNumM, WaveNumN), - concat('x', kPadM, kPadN, kPadK)); + concat('x', kPadM, kPadN, kPadK), + Problem::GetName()); // clang-format on } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 71e0ebb957..38a22e38ac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -36,17 +36,13 @@ struct BaseGemmPipelineAgBgCrMem // TODO: Is this 32K value gfx9 arch specific? static constexpr index_t MinMemInFlyBytes = 32768; - static constexpr index_t WgpPerCU = - (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; + static constexpr index_t WgpPerCU = ck_tile::max(4 * get_warp_size() / BlockSize, 1); static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, (MPerBlock * sizeof(ADataType) / APackedSize + NPerBlock * sizeof(BDataType) / BPackedSize) * KPerBlock); - static constexpr index_t PrefetchStages = - FullMemBandPrefetchStages >= 2 - ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 - : 2; + static constexpr index_t PrefetchStages = ck_tile::clamp(FullMemBandPrefetchStages, 2, 8); static constexpr index_t LocalPrefillStages = 1; static constexpr index_t GlobalBufferNum = PrefetchStages; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 2c6b1f3d48..e35f4ce70d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem return concat('_', "gemm_problem", concat('x', kBlockSize), concat('x', kPadM, kPadN, kPadK), - Scheduler); + Scheduler, + "NumWaveGroups", + NumWaveGroups, + "DoubleSmemBuffer", + DoubleSmemBuffer + ); // clang-format on } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index d843916f5e..76341af70b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -545,7 +545,7 @@ struct UniversalGemmBasePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA() { using AsLayout = remove_cvref_t; using AsDataType = remove_cvref_t; @@ -555,6 +555,11 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t{}, AsLayout>>; using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeA; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -584,6 +589,11 @@ struct UniversalGemmBasePolicy 
using BLayout = remove_cvref_t{}, BsLayout>>; using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(Problem::FixedVectorSize) + { + return Problem::VectorSizeB; + } + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; @@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + load_int4_tile( + a_warp_tile_, a_block_window); + // If B datatype were pkint4 it would be converted prior to storing in LDS + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -410,11 +423,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, 
a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 0ded65ce2e..13897d24c8 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -466,7 +466,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::BQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -821,7 +820,9 @@ struct QuantGemmKernel { if constexpr(PreshuffleQuant) { - static_assert(std::is_same_v); + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); using QuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, @@ -836,14 +837,35 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); + + if constexpr(std::is_same_v) + { + // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] + // Dimensions: [K/QuantGroupK, N/QuantGroupN] + // Strides: [N/QuantGroupN, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] + // Dimensions: [N/QuantGroupN, K/QuantGroupK] + // Strides: [K/QuantGroupK, 1] + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + 
integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } } } else @@ -1068,10 +1090,10 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::BQuantGrouped) { + using QuantGroupSize = remove_cvref_t; if constexpr(PreshuffleQuant) { static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; @@ -1087,13 +1109,23 @@ struct QuantGemmKernel } else { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_pad_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } } } else diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index caa6aad363..726f678d37 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -163,7 +163,6 @@ struct QuantGroupedGemmKernel static constexpr index_t kBlockSize = GemmPipeline::BlockSize; static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel; - static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -262,10 +261,9 @@ struct QuantGroupedGemmKernel auto karg = 
QuantGroupedGemmKernelArgs{type_convert(gemm_descs[i].a_ptr), type_convert(gemm_descs[i].b_ptr), - type_convert(gemm_descs[i].e_ptr), type_convert(gemm_descs[i].aq_ptr), type_convert(gemm_descs[i].bq_ptr), - gemm_descs[i].k_batch, + type_convert(gemm_descs[i].e_ptr), M, N, K, @@ -275,7 +273,8 @@ struct QuantGroupedGemmKernel stride_b, stride_e, gemm_descs[i].stride_AQ, - gemm_descs[i].stride_BQ}; + gemm_descs[i].stride_BQ, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -342,16 +341,32 @@ struct QuantGroupedGemmKernel else { - RunGemmWithPipelineSelection(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); + if constexpr(UsePersistentKernel) + { + RunGemmWithPipelineSelection(a_ptr, + b_ptr, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else // Non-persistent kernel + { + Base::RunGemm({a_ptr}, + {b_ptr}, + aq_ptr, + bq_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } } @@ -451,7 +466,24 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + const auto& aq_block_window = gemm_tile_windows.at(Base::I1); + // Run GEMM pipeline + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + auto& c_block_window = gemm_tile_windows.at(Base::I4); + + // Run Epilogue Pipeline + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) { const auto& bq_block_window = gemm_tile_windows.at(Base::I3); // Run GEMM pipeline @@ -496,6 +528,53 @@ struct 
QuantGroupedGemmKernel } } + CK_TILE_DEVICE index_t FindGroupId(const QuantGemmTransKernelArg* gemm_desc_ptr, + index_t block_id, + index_t group_count) const + { + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) >> 1); + + while((!(block_id >= gemm_desc_ptr[group_id].block_start && + block_id < gemm_desc_ptr[group_id].block_end)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) >> 1); + } + + return group_id; + } + + // For non-persistent kernels + template > + CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + index_t group_count) const + { + const index_t block_id = ck_tile::get_block_1d_id(); + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); + const auto& kargs = gemm_desc_ptr[group_id]; + + const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N); + const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex( + 0, + kargs.group_karg.M, + kargs.group_karg.N, + (block_id - kargs.block_start) % grid_size_2d); + Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d); + } + // For persistent kernels template , diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index f3c8b7a1a3..7f89d98349 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -80,6 +80,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = 
Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -165,6 +168,19 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = PipelineImplBase; + template + CK_TILE_DEVICE static void + LoadAndConvertATile(ABlockTile_& a_block_tile, + ADramWindow& a_dram_window, + const DramTileWindowStep& dram_tile_window_step) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + move_tile_window(a_dram_window, dram_tile_window_step); + } + template const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, const AQDramBlockWindowTmp& aq_dram_block_window_tmp, - index_t m, + [[maybe_unused]] index_t m, index_t num_loop, void* p_smem) const { - (void)m; // unused variable static_assert( std::is_same_v> && std::is_same_v std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); static_assert(!PreshuffleQuant, "Memory pipeline does not support PreshuffleQuant!"); - static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}], - "Aq block window has incorrect lengths for defined AqLayout!"); static_assert(is_a_col_major ? 
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -217,7 +228,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem "B block window has incorrect lengths for defined BLayout!"); // A/B tiles in LDS - using the same approach as regular gemm pipeline - auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::template GetABLdsTensorViews(p_smem); auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); @@ -249,7 +260,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -272,7 +283,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem is_aq_col_major ? 
make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ); // Global prefetch initialization - DRAM to VGPRs - Base::GlobalPrefetch( + LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); @@ -282,10 +293,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS prefill - VGPRs to LDS - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -293,10 +304,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -306,9 +317,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Additional prefetching for memory pipeline - DRAM to VGPRs static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, 
b_dram_tile_window_step); @@ -325,16 +336,17 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, aq_block_tiles.get(number{}), a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); // Prepare next iteration data - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d( a_shuffle_tmp, @@ -348,7 +360,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -365,9 +377,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem b_element_func); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); + LoadAndConvertATile(a_block_tiles.get(number{}), + a_copy_dram_window, + a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); @@ -381,20 +393,89 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } // Tail handling - block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - block_gemm( - c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + auto HotLoopTail = [&](auto tail_num) { + static_for<0, tail_num - 1, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + 
block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + // no second block_sync_lds because it's interwave - if constexpr(TailNum == TailNumber::Even) - { + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, + a_block_tiles.get(number{})); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{})); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, + b_block_tiles.get(number{})); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{})); + } + }); - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func); - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, + aq_block_tiles.get(number{}), + a_lds_gemm_window, + b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm( - c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window); + c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == 
TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); } return c_block_tile; } @@ -413,7 +494,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 30b9d70eb8..e7bd4a2626 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -319,6 +319,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem, + index_t m = 0) const + { + const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) { + constexpr bool hot_loop = has_hot_loop_.value; + constexpr auto tail_num = tail_number_.value; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + aq_dram_block_window_tmp, + m, // dummy value, 
won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index 4cd343e640..c570d4a131 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - - using YPerTile = number; - using XPerTile = number; + using YPerTile = + std::conditional_t, + number, + number>; + using XPerTile = + std::conditional_t, + number, + number>; auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile(), XPerTile()), + make_tuple(YPerTile{}, XPerTile{}), bq_dram_block_window_tmp.get_window_origin(), Policy::template MakeBQDramTileDistribution()); return bq_copy_dram_window; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index a09deabab7..93c35843bb 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); - return GetABQGlobalVectorLoadSize(); + // Support both RowMajor and ColumnMajor layouts for BQ + if constexpr(std::is_same_v) + { + return GetABQGlobalVectorLoadSize(); + } + else + { + static_assert(std::is_same_v); + return GetABQGlobalVectorLoadSize(); 
+ } } template @@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_bq< @@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), Problem::QuantGroupSize::kN, + BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } else { + // KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups) using TileEncodingPattern = tile_distribution_encoding_pattern_bq; + KPerBlockBQ, // Logical K dimension + NPerBlockBQ, // Logical N dimension + Problem::QuantGroupSize::kN, + BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 4883a30f57..2c191cc2b4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,6 +33,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading + using OverrideBDataType = + std::conditional_t, ADataType, BDataType>; + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; using I1 = number<1>; @@ -83,6 +87,9 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -125,7 +132,7 @@ struct 
BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, + const BDramWindow& b_dram_window) + { + using DestDataType = typename BBlockTile_::DataType; + using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(b_block_tile, b_dram_window); + } + template ; - constexpr bool is_bq_col_major = - std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - - static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + constexpr bool is_bq_row_major = + std::is_same_v; static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -212,12 +227,22 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -237,7 +262,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -258,18 +283,20 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}), 0) - : is_bq_col_major ? make_array(0, KPerBlockBQ) - : make_array(KPerBlockBQ, 0); + : is_bq_row_major ? 
make_array(KPerBlockBQ, 0) + : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // B tile gets converted to A datatype during loading + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -281,9 +308,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // B datatype is converted to A datatype during loading + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -294,11 +322,13 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledARegTileDistribution()); @@ -322,9 +352,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType PkInt4 gets converted during loading earlier + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -335,7 +366,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + // Note: BDataType gets converted during loading from PkInt4 + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); 
transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -393,7 +427,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -210,41 +211,45 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) /// /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (XPerQ) relative + /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative /// to warp dimensions. /// /// Three distinct distribution patterns are handled: /// - /// 1. Fine-grained quantization (XPerQ < WarpGemm::kN): + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / XPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = XPerQ) for scale broadcast - /// - Example: XPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp /// - /// 2. Medium-grained quantization (WarpGemm::kN <= XPerQ <= WarpGemm::kN * NWarps): + /// 2. 
Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = XPerQ / WarpGemm::kN - /// - Example: XPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 /// - /// 3. Coarse-grained quantization (XPerQ > WarpGemm::kN * NWarps): + /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): /// - Quantization group spans multiple warps /// - All warps share the same scale value - /// - Example: XPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale /// /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { + // Preshuffle only supported for ColumnMajor currently + static_assert(!(PreshuffleQuant && std::is_same_v), + "PreshuffleQuant only supported for ColumnMajor BQLayout"); + if constexpr(PreshuffleQuant) { if constexpr(YPerQ < WarpGemm::kN) { - constexpr index_t X1 = WarpGemm::kN / YPerQ; // 2 + constexpr index_t X1 = WarpGemm::kN / NPerQ; // 2 constexpr index_t X0 = 256 / 128; // 2 constexpr index_t X2 = 1; - constexpr index_t XR1 = YPerQ; // 8 + constexpr index_t XR1 = NPerQ; // 8 constexpr index_t XR0 = warp_size / (X0 * X1 * XR1); // 2 constexpr index_t Y1 = NWarps; // 4 - constexpr index_t Y0 = YPerTile / Y1; // 1 + constexpr index_t Y0 = KPerTile / Y1; // 1 constexpr index_t YR = 1; return make_static_tile_distribution( @@ -259,55 +264,97 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(YPerQ < WarpGemm::kN) + if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization 
scales within a single warp - constexpr index_t X = XPerTile; // Full X dimension of tile - constexpr index_t XR = 1; // No Y replication needed - constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t Y1 = NWarps; // Number of warps in N-dim - constexpr index_t Y2 = - WarpGemm::kN / YPerQ; // Number of scales per warp 16/ 8 = 2 - constexpr index_t YR = YPerQ; // Elements per quantization group 8 + // N dimension needs to be partitioned the same way regardless of layout + constexpr index_t NR = 1; // No N replication needed + constexpr index_t N0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t N1 = NWarps; // Number of warps in N-dim + constexpr index_t N2 = WarpGemm::kN / NPerQ; // Number of scales per warp - static_assert(Y0 * Y1 * Y2 == YPerTile, - "Y0, Y1, Y2 must cover the blocktile along Y."); + static_assert(N0 * N1 * N2 == NPerTile, + "N0, N1, N2 must cover the blocktile along N dimension."); - return make_static_tile_distribution( - tile_distribution_encoding< - sequence, - tuple, sequence>, - tuple, sequence<0, 1, 0>>, //(Mwarp, Nwarp), (XR, Y2[no of - // scales per warp], YR) - tuple, sequence<1, 2, 2>>, - sequence<1, 2>, //(NiterPerWarp, X(threads in x dimension)) - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1, N2), K] - N on Y-axis, partition Y + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1, N2)] - N on X-axis, partition X + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, + tuple, sequence<1, 2, 2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else if constexpr(YPerQ <= WarpGemm::kN * NWarps) + else if constexpr(NPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - 
constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto Y1 = NWarps / YR; // Warps per unique scale - constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + constexpr auto NR = NPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto N1 = NWarps / NR; // Warps per unique scale + constexpr auto N0 = NPerTile / N1; // Iterations to cover N dimension + + if constexpr(std::is_same_v) + { + // ColumnMajor: [(N0, N1), K] - N on Y-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, (N0, N1)] - N on X-axis + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } - else // XPerQ > WarpGemm::kN * NWarps + else // NPerQ > WarpGemm::kN * NWarps { // Case 3: Coarse-grained - quantization group spans all warps // All warps in N-dimension share the same quantization scale - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, - tuple, sequence<2>>, - sequence<2, 1>, - sequence<0, 0>>{}); + if constexpr(std::is_same_v) + { + // ColumnMajor: [N, K] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } + else + { + // RowMajor: [K, N] + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, + tuple, sequence<2>>, + sequence<2, 1>, + sequence<0, 0>>{}); + } } } } diff --git 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index d83338fbb2..51f0f5f1b1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -99,28 +99,49 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV template CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) constexpr index_t Aload_inst = (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) constexpr index_t Bload_inst = (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), QuantGroupSize::kK * QuantGroupSize::kK), VectorLoadSize); - constexpr index_t kLdsVec = 8; + + // ToDo: Hardcoded, need to change in future. 
How many instruction emit per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; - constexpr index_t ds_read_inst = kMPerBlock / kLdsVec; - constexpr index_t ds_write_inst = Aload_inst; - constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); - constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. constexpr index_t buffer_load_rep = min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma - static_for<0, nloop, 1>{}([&](auto j_inst) { - ignore = j_inst; + static_for<0, nloop, 1>{}([&](auto) { static_for<0, mfma_inst, 1>{}([&](auto i_inst) { __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model. 
if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) { __builtin_amdgcn_sched_group_barrier( @@ -140,6 +161,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read } } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU }); }); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index e172e732fa..46c60cb6d7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off return concat('_', "grouped_convolution_backward_data", gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, "gemm", GemmPipeline::GetName(), "epilogue", - EpiloguePipeline::GetName()); + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp 
b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 6ef1d84a6e..f43bfdacac 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { - constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } else { - return concat('_', "grouped_convolution_backward_weight", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), "merge", NumGroupsToMerge); - } + return concat('_', "grouped_convolution_backward_weight", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 72ba17c5a5..a9f3274805 100644 --- 
a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel { constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - if (NumGroupsToMerge > 1) { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName(), - "merge", - NumGroupsToMerge); - } else { - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); - } + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + InLayout::name, + WeiLayout::name, + OutLayout::name, + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + getConvSpecializationString(ConvSpecialization), + "MergedGroups", + NumGroupsToMerge, + "SplitImage", + EnableSplitImage, + "ExplicitGemm", + GroupedConvTraitsType_::ExplicitGemm + ); // clang-format on } + [[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); } + #ifdef CK_EXPERIMENTAL_BUILDER CK_TILE_HOST std::string GetInstanceString() const { diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 71739c9083..5b00e53af8 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -9,6 +9,13 @@ namespace ck_tile { +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + /// @brief The Grouped Conv kernel host arguments. 
/// /// @par Overview @@ -113,6 +120,36 @@ struct GroupedConvTraits using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor; + template + struct GemmLayouts + { + static_assert(false, "Unsupported direction."); + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutFwd; + using BsLayout = BsLayoutFwd; + using CLayout = CLayoutFwd; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdData; + using BsLayout = BsLayoutBwdData; + using CLayout = CLayoutBwdData; + }; + + template <> + struct GemmLayouts + { + using AsLayout = AsLayoutBwdWeight; + using BsLayout = BsLayoutBwdWeight; + using CLayout = CLayoutBwdWeight; + }; + template using GroupedConvImplicitGemmTraitsFwd = TileGemmTraits; diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index affa6d987b..aeec7bd471 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -90,7 +90,7 @@ submodule = submodule_t() # formatting format_procs = [] for x in all_files: - dos2unix = f"python -m dos2unix {str(x)} {str(x)}" + dos2unix = f"python3 -m dos2unix {str(x)} {str(x)}" clang_format = f"clang-format -style=file -i {str(x)}" # One process to avoid race conditions. 
cmd = f"{dos2unix} && {clang_format}" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 03e3ae88a3..89009c6d0b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( @@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( op_ptrs); @@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); } #endif -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 
+ if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); @@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp index cd65a2285a..84a715b70a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in PassThrough, PassThrough, Bilinear>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { static_assert(is_same_v, "ComputeTypeA and ComputeTypeB must be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp index 36980e5935..c898dbf781 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta PassThrough, PassThrough, Scale>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( std::vector && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { static_assert(is_same_v, " only support same compute type"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } - else if constexpr(is_same_v) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index e677f6f848..3fe8fa9c5a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -347,12 +347,12 @@ struct 
DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && @@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index 448a6b5d51..a0e8e46570 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_ PassThrough, Bilinear, PassThrough>>>& instances); +#endif + +#ifdef 
CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index acf9c9e150..64bbdf6ec5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, Scale, PassThrough>>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: this operator requires the same compute type"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } 
#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ba2f6b921a..5089ea2c1e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same!"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); @@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances( @@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && @@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( @@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp index 46bc0d2320..d4729f4d13 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp @@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory && @@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "A and B compute types should be the same"); +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v) { @@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory) +#endif +#ifdef CK_ENABLE_TF32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); @@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory && @@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if 
constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 1 && is_same_v, NDHWGK>) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 90852d2945..090c99819f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); @@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory && @@ -218,12 +219,12 @@ struct 
DeviceOperationInstanceFactory && is_same_v && is_same_v) { static_assert(is_same_v, "Error: AComputeType and BComputeType should be the same"); +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( @@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); @@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory>>& instances); +#endif +#ifdef CK_ENABLE_TF32 void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && DLayouts::Size() == 0) { -#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_TF32 if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); } - else +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } - } #endif + } #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index eeaf269394..ef037526ca 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME) set(type1 "_f16") elseif(type MATCHES "fp32") set(type1 "_f32") + elseif(type MATCHES "tf32") + set(type1 "_tf32") elseif(type MATCHES "fp8") set(type1 "_f8") elseif(type MATCHES "bf16") @@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME) #if filename matches any selected type, exit type loop and do no exclude the file from the list set(test 0) break() - elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name 
MATCHES "int8" OR source_name MATCHES "fp16" OR - source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND + elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR + source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND NOT (source_name MATCHES type OR source_name MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) @@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Only build tf32 instances for gfx942 & gfx950 - if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_") - message(DEBUG "removing tf32 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") + if(source_name MATCHES "_tf32_") + if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) + message(DEBUG "removing tf32 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endif() endforeach() @@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "fp32 instance found!") set(add_inst 1) endif() + if(("${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32") + message(DEBUG "tf32 instance found!") + set(add_inst 1) + endif() if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") message(DEBUG "fp64 instance found!") set(add_inst 1) @@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list}) "${cmake_instance}" MATCHES "_f16" OR "${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32" OR 
+ "${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64" OR "${cmake_instance}" MATCHES "_bf16" OR @@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list}) list(APPEND CK_DEVICE_OTHER_INSTANCES $) endif() message(DEBUG "add_instance_directory ${subdir_path}") - endif() + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp index 3cda620831..47a12e2d88 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp index 2a7ee6fd66..ac7ab78ed7 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(is_same::value || is_same::value || is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + 
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 62d6e860f9..cbf763fc13 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } @@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout 
== ConvLayout::NHWGC_GKYXC_NHWGK) @@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); -#endif } } } diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index a18aab41a5..c4f154e180 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using namespace ck::tensor_layout::convolution; @@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial 
== 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index c94b77dd4f..4319d849c8 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; 
-#if defined(__gfx942__) || defined(__gfx950__) using TF32 = ck::tf32_t; -#endif // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) @@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // NHWGC_GKYXC_NHWGK @@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) @@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return 
profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } // NGCDHW_GKCZYX_NGKDHW @@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp index 4eb12e6e19..79b9beb8c7 100644 --- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == 
ConvLayout::NHWGC_GKYXC_NHWGK) @@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index 7df9fd6167..f497ee8da5 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; -#if defined(__gfx942__) using TF32 = ck::tf32_t; -#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); -#endif } } diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py index 0eee25ecaa..089a2d439b 100644 --- a/python/ck4inductor/__init__.py +++ b/python/ck4inductor/__init__.py @@ -6,7 +6,7 @@ def __version__(): import subprocess # needs to be manually updated - rocm_version = "7.0.1" + rocm_version = "7.1.1" hash_width = 6 try: hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f8498c6c03..c221f11f46 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -65,6 
+65,9 @@ function(add_test_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() @@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME) if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() + if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES) + set(test 1) + endif() if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 38bd59b882..39a7c66f38 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -86,8 +86,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; - // BQLayout is always ColumnMajor for BQuant - using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + // Re-use the AQLayout for BQLayout + using BQLayout = AQLayout; using CodegenGemmTraits = ck_tile::TileGemmQuantTraits>; using GroupSize2D128N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: // clang-format off using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - 
std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp index 6cde4bded5..3a62fc091a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -26,60 +26,60 @@ using GroupSize2D32N = ck_tile::QuantGroupShape>; using GroupSize2D64N = ck_tile::QuantGroupShape>; // Type combinations for BQuant tests with PreshuffleB -// Tuple format: // clang-format off using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - 
std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7b16529aa8..bf9c7a138d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -389,6 +389,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(BQLayout{}) ? 
BQN : BQK; // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); - // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 2bd2571993..7a7ae77730 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -14,6 +14,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) + target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp index 551989421f..6a1a28884a 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant.cpp @@ -18,32 +18,41 @@ using True = ck_tile::bool_constant; using False = ck_tile::bool_constant; using RowColQuant = std::integral_constant; using TensorQuant = std::integral_constant; +using AQuant 
= std::integral_constant; using BQuant = std::integral_constant; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, 
BF8, F32, BF8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>, - std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + 
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..8dcd6d017d --- /dev/null +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_aquant.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_util_quant.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; +using AQuant = std::integral_constant; + +// clang-format off +using KernelTypes_AQuant = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, True>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, True>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuant, False, False, False> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_AQuant, KernelTypes_AQuant); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_AQuant +#include 
"test_grouped_gemm_quant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp index 4f44acf4c4..6c0ad545b7 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_bquant.cpp @@ -20,9 +20,14 @@ using BQuant = std::integral_constant, - std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False, True, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False, False, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp index 48720aeebf..cc1b32fb20 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_rowcol.cpp @@ -20,11 +20,14 @@ using RowColQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, 
TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False, False, False> >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp index f59fa29ec2..e446f7b168 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_quant_tensor.cpp @@ -20,11 +20,14 @@ using TensorQuant = std::integral_constant, - std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>, - std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False> + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB, Persistent, TransposeC + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, True, False>, + + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False>, + std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False, False, False> >; // clang-format on diff --git 
a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index 68b6735655..9941066c3e 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" @@ -32,24 +33,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using AQLayout = Row; using BQLayout = Col; - static constexpr bool Persistent = true; static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value; - - template - static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() - { -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 
32 : 64; -#endif - } + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + static constexpr bool TransposeC = std::tuple_element_t<12, Tuple>::value; struct GroupedGemKernelParam_Mfma { @@ -66,11 +52,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test static const ck_tile::index_t N_Warp = 2; static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = - TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile(); + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 32; }; struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma @@ -90,16 +74,201 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); } + template + float invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr bool DoubleSmemBuffer = + PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B + + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + + using QuantGroupSize = ck_tile::QuantGroupShape>; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + 
ck_tile::BaseGemmPipelineAgBgCrCompV3, + std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>>, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; + + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemKernelParam::M_Warp, + GroupedGemKernelParam::N_Warp, + GroupedGemKernelParam::M_Warp_Tile, + GroupedGemKernelParam::N_Warp_Tile, + GroupedGemKernelParam::K_Warp_Tile, + QuantGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = 
ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr) { - constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B - constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -131,40 +300,53 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test BQLayout, TransposeC, DoubleSmemBuffer, - true>; + Persistent>; const auto Run = [&](const auto memory_operation_) { constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; constexpr auto memory_operation = memory_operation_.value; - constexpr bool transpose_c = false; // We create the GEMM pipeline without specifying hotloop or tailnumber. // These are automatically run inside the kernel based on the given input data. 
- using QuantGemmProblem = typename std::conditional< - QuantType == ck_tile::QuantType::BQuantGrouped, - ck_tile::GemmBQuantPipelineProblem, + + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, ck_tile::GemmRowColTensorQuantPipelineProblem>::type; + scheduler>>; using GemmPipeline = std::conditional_t< - QuantType == ck_tile::QuantType::RowColQuant || - QuantType == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem( + ck_tile::make_kernel( Kernel{}, grids, blocks, @@ -292,13 +474,24 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test AQK = 1; // Row quantization: tensor shape [M, 1] or [1] BQK = 1; // Column quantization: tensor shape [1, N] or [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize + BQK = 0; // No B quantization + if(K % QuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for AQuantGrouped mode"); + } + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { - AQK = 0; // No A quantization - BQK = K / 128; // Group quantization: BQK = K / GroupSize - if(K % 128 != 0) + AQK = 0; // No A quantization + BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize + if(K % QuantGroupSize::kK != 0) { - throw std::runtime_error("K must be divisible by 128 for 
BQuantGrouped mode"); + throw std::runtime_error( + "K must be divisible by QuantGroupSize::kK for BQuantGrouped mode"); } } @@ -317,6 +510,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test stride_AQs[i] = 1; // Tensor quantization: tensor shape [1] stride_BQs[i] = 1; // Tensor quantization: tensor shape [1] } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout())); + stride_BQs[i] = 0; // No B quantization + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { stride_AQs[i] = 0; // No A quantization @@ -348,11 +547,20 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::HostTensor(ck_tile::host_tensor_descriptor( 1, 1, stride_BQs[i], is_row_major(BQLayout())))); } + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) + { + aq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back( + ck_tile::HostTensor(ck_tile::host_tensor_descriptor( + 0, 0, stride_BQs[i], is_row_major(BQLayout())))); + } else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) { aq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( - 0, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + 0, 0, stride_AQs[i], is_row_major(AQLayout{})))); bq_tensors.push_back( ck_tile::HostTensor(ck_tile::host_tensor_descriptor( BQK, N, stride_BQs[i], is_row_major(BQLayout())))); @@ -429,11 +637,12 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if constexpr(Persistent) { // Generate kernel arguments std::vector kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); assert(gemm_descs[0].k_batch == 1); for(const auto& arg : gemm_descs) { @@ -471,7 +680,14 @@ 
class TestCkTileGroupedGemmQuant : public ::testing::Test } else { - GTEST_FAIL() << "Non-persistent kernel not implemented yet"; + const auto stream = ck_tile::stream_config{nullptr, false, 1}; +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#else + invoke_grouped_gemm( + gemm_descs, stream, kargs_ptr); +#endif } // Copy results back to host for validation @@ -512,7 +728,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test bq_tensors[i], c_m_n_host_ref); } - else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + else if constexpr(QuantType == ck_tile::QuantType::AQuantGrouped) { ck_tile::reference_gemm_quant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + } + else if constexpr(QuantType == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::reference_gemm_quant( a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); } @@ -550,5 +777,8 @@ using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant; template using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant; +template +using TestCkTileGroupedGemmQuant_AQuant = TestCkTileGroupedGemmQuant; + template using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant; diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt index aa15293411..01ed83841b 100644 --- a/test/ck_tile/utility/CMakeLists.txt +++ b/test/ck_tile/utility/CMakeLists.txt @@ -3,5 +3,7 @@ message("-- Adding: test/ck_tile/utility/") +add_gtest_executable(test_fill test_fill.cpp) + # Add print tests add_subdirectory(print) diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp new file mode 100644 index 0000000000..18f42c4ad0 --- /dev/null +++ b/test/ck_tile/utility/test_fill.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/host/fill.hpp" +#include "ck_tile/host/joinable_thread.hpp" +#include +#include +#include +#include + +using namespace ck_tile; + +namespace test { + +// Test fixture for FillUniformDistribution tests; holds the seed and the a/b arguments shared by every fill call +template +class FillUniformDistributionTest : public ::testing::Test +{ + public: + static constexpr uint32_t seed = 42; + static constexpr float a = -5.0f; + static constexpr float b = 5.0f; +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(FillUniformDistributionTest, TestTypes); + +// Test that multiple runs with the same seed produce identical results, including reruns made while a cpu_core_guard is active +TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + constexpr size_t size = 1024 * 1024 * 1024 / sizeof(T); // element count for 1 GiB of data + + std::vector vec1(size); + auto start = std::chrono::high_resolution_clock::now(); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + auto end = std::chrono::high_resolution_clock::now(); + double sec = std::chrono::duration(end - start).count(); + std::cout << "Taking " << sec << " sec to fill 1GB of data of type " << typeid(T).name() + << std::endl; + + const auto cpu_cores = max(32U, get_available_cpu_cores()); + for(auto num_threads_diff : {-3, -1}) + { + cpu_core_guard cg(min(max(cpu_cores + num_threads_diff, 1U), get_available_cpu_cores())); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "First and second fill should be identical"; + } +} + +// Test consistency across different data sizes (which affects threading) +TYPED_TEST(FillUniformDistributionTest, ConsistencyAcrossSizes) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will 
use multiple threads + 5000000 // Very large - will use many threads + }; + + for(size_t size : test_sizes) + { + std::vector reference(size); + std::vector test_vec(size); + + FillUniformDistribution{a, b, seed}(reference.begin(), reference.end()); + + // Run multiple times to ensure consistency + for(int run = 0; run < 3; ++run) + { + std::fill(test_vec.begin(), test_vec.end(), T{}); + FillUniformDistribution{a, b, seed}(test_vec.begin(), test_vec.end()); + + EXPECT_EQ(0, std::memcmp(reference.data(), test_vec.data(), size * sizeof(T))) + << "Mismatch for size=" << size << " run=" << run; + } + } +} + +// Test that fills of different sizes with the same seed share a common prefix +TYPED_TEST(FillUniformDistributionTest, CommonPrefix) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + std::vector test_sizes = { + 100, // Small - likely single threaded + 10000, // Medium + 1000000, // Large - will use multiple threads + 5000000 // Very large - will use many threads + }; + + auto longest = std::make_unique>(test_sizes[0]); + FillUniformDistribution{a, b, seed}(longest->begin(), longest->end()); + for(size_t i = 1; i < test_sizes.size(); ++i) + { + auto current = std::make_unique>(test_sizes[i]); + FillUniformDistribution{a, b, seed}(current->begin(), current->end()); + size_t min_size = std::min(longest->size(), current->size()); + EXPECT_EQ(0, std::memcmp(longest->data(), current->data(), min_size * sizeof(T))) + << "Different sizes with same seed should have the same prefix"; + if(current->size() > longest->size()) + { + longest = std::move(current); + } + } +} + +// Test edge cases: empty range, single element, and small sizes around 16/32/64 +TYPED_TEST(FillUniformDistributionTest, EdgeCases) +{ + using T = TypeParam; + const auto a = this->a; + const auto b = this->b; + const auto seed = this->seed; + + // Empty range + std::vector empty_vec; + EXPECT_NO_THROW((FillUniformDistribution{a, b, seed}(empty_vec.begin(), empty_vec.end()))); + + // Single element + std::vector single1(1); + 
std::vector single2(1); + FillUniformDistribution{a, b, seed}(single1.begin(), single1.end()); + FillUniformDistribution{a, b, seed}(single2.begin(), single2.end()); + + EXPECT_EQ(0, std::memcmp(single1.data(), single2.data(), sizeof(T))) + << "Single element should be consistent"; + + // Small sizes that might affect threading decisions + std::vector small_sizes = {2, 3, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65}; + for(size_t size : small_sizes) + { + std::vector vec1(size); + std::vector vec2(size); + FillUniformDistribution{a, b, seed}(vec1.begin(), vec1.end()); + FillUniformDistribution{a, b, seed}(vec2.begin(), vec2.end()); + + EXPECT_EQ(0, std::memcmp(vec1.data(), vec2.data(), size * sizeof(T))) + << "Edge case failed for size=" << size; + } +} +} // namespace test