mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
rebase with develop
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -36,6 +36,9 @@ tags
|
||||
# Editors
|
||||
.vscode
|
||||
|
||||
# CMake formatting configuration (local)
|
||||
.cmake-format.yaml
|
||||
|
||||
# Cline
|
||||
.cline*
|
||||
|
||||
|
||||
10
CHANGELOG.md
10
CHANGELOG.md
@@ -2,6 +2,16 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## (Unreleased) Composable Kernel 1.3.0
|
||||
|
||||
### Added
|
||||
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
|
||||
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
|
||||
|
||||
### Changed
|
||||
|
||||
### Upcoming changes
|
||||
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
|
||||
@@ -92,6 +92,10 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "tf32")
|
||||
# definition will be added based on the GPU target in the following section
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
@@ -106,6 +110,7 @@ else()
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
@@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
set(CK_GFX950_SUPPORT "ON")
|
||||
endif()
|
||||
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
|
||||
add_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
else()
|
||||
message(STATUS "Disabling TF32 instances")
|
||||
remove_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "OFF")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
|
||||
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
@@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
|
||||
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
FROM ubuntu:24.04
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.0.1
|
||||
ARG ROCMVERSION=7.1.1
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
@@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN set -xe && \
|
||||
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
|
||||
apt update && \
|
||||
apt install python3-setuptools python3-wheel -y && \
|
||||
apt install rocm-dev -y
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1"
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1"
|
||||
FROM $BASE_DOCKER
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
|
||||
@@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \
|
||||
git sparse-checkout set projects/hipblaslt shared/origami && \
|
||||
cd projects/hipblaslt && \
|
||||
git show --oneline -s && \
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
|
||||
22
Jenkinsfile
vendored
22
Jenkinsfile
vendored
@@ -288,7 +288,7 @@ def getBaseDockerImageName(){
|
||||
}
|
||||
else{
|
||||
def ROCM_numeric = parseVersion("${params.ROCMVERSION}")
|
||||
if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){
|
||||
if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 2 ){
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}"
|
||||
}
|
||||
else{
|
||||
@@ -434,7 +434,7 @@ def buildDocker(install_prefix){
|
||||
}
|
||||
catch(Exception ex){
|
||||
echo "Unable to locate image: ${image_name}. Building image now"
|
||||
retimage = docker.build("${image_name}", dockerArgs + ' .')
|
||||
retimage = docker.build("${image_name}", dockerArgs)
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.push()
|
||||
}
|
||||
@@ -834,12 +834,14 @@ def Build_CK(Map conf=[:]){
|
||||
if (params.hipTensor_test && arch == "gfx90a" ){
|
||||
// build and test hipTensor on gfx90a node
|
||||
sh """#!/bin/bash
|
||||
rm -rf "${params.hipTensor_branch}".zip
|
||||
rm -rf hipTensor-"${params.hipTensor_branch}"
|
||||
wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip
|
||||
unzip -o "${params.hipTensor_branch}".zip
|
||||
rm -rf rocm-libraries
|
||||
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git
|
||||
cd rocm-libraries
|
||||
git sparse-checkout init --cone
|
||||
git sparse-checkout set projects/hiptensor
|
||||
git checkout "${params.hipTensor_branch}"
|
||||
"""
|
||||
dir("hipTensor-${params.hipTensor_branch}"){
|
||||
dir("rocm-libraries/projects/hiptensor"){
|
||||
sh """#!/bin/bash
|
||||
mkdir -p build
|
||||
ls -ltr
|
||||
@@ -1095,7 +1097,7 @@ def run_pytorch_tests(Map conf=[:]){
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true
|
||||
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
|
||||
@@ -1121,8 +1123,8 @@ pipeline {
|
||||
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
|
||||
string(
|
||||
name: 'ROCMVERSION',
|
||||
defaultValue: '7.0.1',
|
||||
description: 'Specify which ROCM version to use: 7.0.1 (default).')
|
||||
defaultValue: '7.1.1',
|
||||
description: 'Specify which ROCM version to use: 7.1.1 (default).')
|
||||
string(
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: '',
|
||||
|
||||
@@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb
|
||||
|
||||
Additional cmake flags can be used to significantly speed-up the build:
|
||||
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
|
||||
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
|
||||
other data types.
|
||||
|
||||
|
||||
@@ -27,6 +27,9 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "tf32")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
@@ -41,6 +44,7 @@ else()
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
if (GPU_TARGETS MATCHES "gfx94")
|
||||
@@ -67,6 +71,14 @@ if (GPU_TARGETS)
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
set(CK_USE_FNUZ_FP8 "ON")
|
||||
endif()
|
||||
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
|
||||
add_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
else()
|
||||
message(STATUS "Disabling TF32 instances for this target")
|
||||
remove_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "OFF")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core[api_reference]==1.20.1
|
||||
rocm-docs-core[api_reference]==1.31.0
|
||||
sphinxcontrib-bibtex==2.6.5
|
||||
|
||||
@@ -237,7 +237,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core[api-reference]==1.20.1
|
||||
rocm-docs-core[api-reference]==1.31.0
|
||||
# via -r requirements.in
|
||||
rpds-py==0.24.0
|
||||
# via
|
||||
|
||||
@@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
|
||||
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# add fmha_fwd_v3 example
|
||||
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
|
||||
|
||||
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
|
||||
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
|
||||
)
|
||||
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
|
||||
fmha_fwd_v3.cpp
|
||||
${FMHA_FWD_V3_INSTANCES}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-fgpu-flush-denormals-to-zero
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
|
||||
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
|
||||
if(HAS_DISABLE_PACKED_FP32)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-mllvm --amdgpu-disable-packed-fp32=1
|
||||
)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
|
||||
-DCK_TILE_DISABLE_PACKED_FP32=1
|
||||
)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
@@ -30,16 +30,24 @@ _MASK_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def get_mask_map(mask: str):
|
||||
if mask == "generic":
|
||||
def get_mask_map(mask_impl: str):
|
||||
if mask_impl == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask == "simplified":
|
||||
elif mask_impl == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_impl(mask: str) -> str:
|
||||
return "simplified" if mask.startswith("s_") else "generic"
|
||||
|
||||
|
||||
def get_mask_cpp_type(mask: str) -> str:
|
||||
return get_mask_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
@@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
|
||||
return None
|
||||
|
||||
|
||||
def get_mask_cpp_check_expr(mask: str) -> str:
|
||||
return get_mask_check_map(get_mask_impl(mask))[mask]
|
||||
|
||||
|
||||
QSCALE_MAP = {
|
||||
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
|
||||
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
|
||||
@@ -122,6 +134,7 @@ PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
@@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = {
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,616 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include <ck_tile/core/numeric/math.hpp>
|
||||
#include <ck_tile/core/utility/functional.hpp>
|
||||
#include <ck_tile/host/arg_parser.hpp>
|
||||
#include <ck_tile/host/device_memory.hpp>
|
||||
#include <ck_tile/host/fill.hpp>
|
||||
#include <ck_tile/host/check_err.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_masking.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s", "3328", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q & k")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("iperm",
|
||||
"0",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("v", "1", "0:no verify, 1:verify")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel")
|
||||
// Optional effective seqlen override (exclude PAD) for batch mode
|
||||
.insert("q_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
enum class TensorLayout
|
||||
{
|
||||
bhsd,
|
||||
bshd,
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case TensorLayout::bhsd: return stream << "bhsd";
|
||||
case TensorLayout::bshd: return stream << "bshd";
|
||||
default: return stream << "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
explicit Problem(const ck_tile::ArgParser& args)
|
||||
{
|
||||
data_type = args.get_str("prec") == "fp16"
|
||||
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
|
||||
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
|
||||
batch = args.get_int("b");
|
||||
seqlen_q = args.get_int("s");
|
||||
seqlen_k = args.get_int("s_k");
|
||||
if(seqlen_k < 0)
|
||||
{
|
||||
seqlen_k = seqlen_q;
|
||||
}
|
||||
nhead_q = args.get_int("h");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
if(nhead_kv < 0)
|
||||
{
|
||||
nhead_kv = nhead_q;
|
||||
}
|
||||
hdim = args.get_int("d");
|
||||
softmax_scale = args.get_float("scale_s");
|
||||
if(softmax_scale == .0f)
|
||||
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
|
||||
const auto is_causal = args.get_bool("causal");
|
||||
if(is_causal)
|
||||
{
|
||||
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
mask = mask_info::decode("0", seqlen_q, seqlen_k);
|
||||
}
|
||||
|
||||
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
q_eff_lens = args.get_int_vec("q_eff_lens");
|
||||
kv_eff_lens = args.get_int_vec("kv_eff_lens");
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_query_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_q, seqlen_q, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_q, nhead_q, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_key_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_kv, seqlen_k, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_k, nhead_kv, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_value_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_kv, seqlen_k, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_k, nhead_kv, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_output_shape() const
|
||||
{
|
||||
if(output_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_q, seqlen_q, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_q, nhead_q, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_kv;
|
||||
ck_tile::index_t hdim;
|
||||
float softmax_scale;
|
||||
mask_info mask;
|
||||
TensorLayout input_layout;
|
||||
TensorLayout output_layout;
|
||||
std::vector<int> q_eff_lens;
|
||||
std::vector<int> kv_eff_lens;
|
||||
};
|
||||
|
||||
struct RunConfig
|
||||
{
|
||||
explicit RunConfig(const ck_tile::ArgParser& args)
|
||||
{
|
||||
seed = args.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
kernel_warmup = args.get_int("warmup");
|
||||
kernel_repeat = args.get_int("repeat");
|
||||
verify = args.get_bool("v");
|
||||
}
|
||||
|
||||
std::optional<uint32_t> seed;
|
||||
int kernel_warmup;
|
||||
int kernel_repeat;
|
||||
bool verify;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
auto generate_qkv(const Problem& problem,
|
||||
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
|
||||
-> std::tuple<ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>>
|
||||
{
|
||||
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
|
||||
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
|
||||
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
|
||||
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
|
||||
|
||||
return std::make_tuple(q, k, v);
|
||||
}
|
||||
|
||||
namespace host {
|
||||
template <typename AccDataType,
|
||||
typename PDataType,
|
||||
typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename QElementOp,
|
||||
typename KElementOp,
|
||||
typename VElementOp,
|
||||
typename SAccElementOp>
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
const VElementOp& v_element_op = {},
|
||||
const SAccElementOp& s_acc_element_op = {})
|
||||
{
|
||||
const int batch_size = q_bshd.mDesc.get_lengths()[0];
|
||||
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
|
||||
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
|
||||
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
|
||||
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
|
||||
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
|
||||
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
|
||||
|
||||
const int nr = nhead_q / nhead_kv;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
|
||||
// do computation for each batch
|
||||
for(int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
// copy per-batch data from input tensors
|
||||
// clang-format off
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
|
||||
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
|
||||
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// clang-format on
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, seqlen_q, seqlen_kv));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
}
|
||||
}
|
||||
} // namespace host
|
||||
|
||||
template <typename DataType>
|
||||
bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
{
|
||||
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
|
||||
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
|
||||
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q.data());
|
||||
k_buf.ToDevice(k.data());
|
||||
v_buf.ToDevice(v.data());
|
||||
// Ensure output buffer is zero-initialized so padded regions compare cleanly
|
||||
o_buf.SetZero();
|
||||
|
||||
ck_tile::fmha_fwd_v3_args args{};
|
||||
|
||||
args.data_type = problem.data_type;
|
||||
args.batch = problem.batch;
|
||||
args.seqlen_q = problem.seqlen_q;
|
||||
args.seqlen_k = problem.seqlen_k;
|
||||
args.nhead_q = problem.nhead_q;
|
||||
args.nhead_kv = problem.nhead_kv;
|
||||
args.hdim_qk = problem.hdim;
|
||||
args.hdim_v = problem.hdim;
|
||||
args.softmax_scale = problem.softmax_scale;
|
||||
|
||||
args.window_size_left = problem.mask.left;
|
||||
args.window_size_right = problem.mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
|
||||
|
||||
// bshd: (batch, seqlen_q, nhead_q, hdim)
|
||||
// bhsd: (batch, nhead_q, seqlen_q, hdim)
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.stride_q =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_q =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
|
||||
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_k, nhead_kv, hdim)
|
||||
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.stride_k =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_k =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
|
||||
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_k, nhead_kv, hdim)
|
||||
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.stride_v =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_v =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
|
||||
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_q, nhead_q, hdim)
|
||||
// bhsd: (batch, nhead_q, seqlen_q, hdim)
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
args.stride_o =
|
||||
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
|
||||
? problem.hdim
|
||||
: problem.seqlen_q * problem.hdim;
|
||||
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
|
||||
|
||||
// Optional cumulative seqlen overrides (exclude PAD)
|
||||
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
|
||||
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
|
||||
|
||||
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
|
||||
std::vector<ck_tile::index_t> eff;
|
||||
if(!opt_vec.empty() && opt_vec[0] != -1)
|
||||
{
|
||||
eff.assign(opt_vec.begin(), opt_vec.end());
|
||||
if(eff.size() < static_cast<size_t>(problem.batch))
|
||||
{
|
||||
eff.resize(problem.batch, eff.back());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
eff.assign(problem.batch, fallback);
|
||||
}
|
||||
return eff;
|
||||
};
|
||||
|
||||
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
|
||||
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
|
||||
|
||||
// Calculate cumulative sums for kernel arguments if varlen is used
|
||||
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
|
||||
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
|
||||
std::vector<ck_tile::index_t>& cum_vec) {
|
||||
cum_vec.resize(per_batch_vec.size() + 1);
|
||||
cum_vec[0] = 0;
|
||||
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
|
||||
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
||||
};
|
||||
|
||||
if(has_varlen_q)
|
||||
{
|
||||
calculate_cumulative(eff_q_vec, cuq_cum);
|
||||
}
|
||||
if(has_varlen_k)
|
||||
{
|
||||
calculate_cumulative(eff_kv_vec, cukv_cum);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
|
||||
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
|
||||
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
|
||||
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
|
||||
args.cu_seqlen_q_ptr =
|
||||
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
|
||||
: nullptr;
|
||||
args.cu_seqlen_kv_ptr =
|
||||
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
|
||||
: nullptr;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/*log_level=*/0,
|
||||
run_config.kernel_warmup,
|
||||
run_config.kernel_repeat};
|
||||
|
||||
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t flop = [&] {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
{
|
||||
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
else
|
||||
{
|
||||
/// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
}();
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
if(problem.input_layout == problem.output_layout)
|
||||
{
|
||||
std::cout << problem.input_layout;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << problem.input_layout << "-" << problem.output_layout;
|
||||
}
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
|
||||
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
|
||||
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
|
||||
<< " TFlops" << std::endl;
|
||||
|
||||
if(!run_config.verify)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// transpose tensor descriptors from bhsd to bshd if necessary
|
||||
if(problem.input_layout != TensorLayout::bshd)
|
||||
{
|
||||
q = q.transpose({0, 2, 1, 3});
|
||||
k = k.transpose({0, 2, 1, 3});
|
||||
v = v.transpose({0, 2, 1, 3});
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
if(problem.output_layout != TensorLayout::bshd)
|
||||
{
|
||||
o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
}
|
||||
|
||||
// If variable lengths are provided, compute per-batch references
|
||||
// with the effective lengths; else compute a single full reference.
|
||||
if(has_varlen_q || has_varlen_k)
|
||||
{
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
o_ref.SetZero();
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
|
||||
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
|
||||
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// No varlen override: compute the full reference once
|
||||
host::fmha_fwd<float, DataType>(q,
|
||||
k,
|
||||
v,
|
||||
problem.mask,
|
||||
o_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [parse_result, args] = parse_cmd_args(argc, argv);
|
||||
if(!parse_result)
|
||||
{
|
||||
std::cerr << "failed to parse command line arguments" << std::endl;
|
||||
}
|
||||
|
||||
Problem problem(args);
|
||||
RunConfig run_config(args);
|
||||
|
||||
const auto run = [&] {
|
||||
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
|
||||
{
|
||||
return run_impl<ck_tile::fp16_t>(problem, run_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_impl<ck_tile::bf16_t>(problem, run_config);
|
||||
}
|
||||
};
|
||||
|
||||
return !run();
|
||||
}
|
||||
@@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
|
||||
/// maximizes the kernel's performance.
|
||||
int remap_opt = 2;
|
||||
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
|
||||
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
|
||||
{
|
||||
if(65536 <= args.seqlen_q)
|
||||
{
|
||||
remap_opt = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
remap_opt = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto kargs = [&] {
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
0, // batch_stride_lse
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
{
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
|
||||
{
|
||||
switch(data_type)
|
||||
{
|
||||
case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16";
|
||||
case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16";
|
||||
default: return stream << "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
}
|
||||
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct fmha_fwd_v3_args
|
||||
{
|
||||
enum class data_type_enum
|
||||
{
|
||||
fp16,
|
||||
bf16
|
||||
};
|
||||
|
||||
data_type_enum data_type;
|
||||
// bool is_varlen;
|
||||
|
||||
index_t batch;
|
||||
index_t seqlen_q;
|
||||
index_t seqlen_k;
|
||||
index_t nhead_q;
|
||||
index_t nhead_kv;
|
||||
index_t hdim_qk;
|
||||
index_t hdim_v;
|
||||
|
||||
float softmax_scale;
|
||||
|
||||
index_t window_size_left;
|
||||
index_t window_size_right;
|
||||
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
||||
// window_size_right == 0).
|
||||
|
||||
const void* q_ptr;
|
||||
index_t stride_q;
|
||||
index_t nhead_stride_q;
|
||||
index_t batch_stride_q;
|
||||
|
||||
const void* k_ptr;
|
||||
index_t stride_k;
|
||||
index_t nhead_stride_k;
|
||||
index_t batch_stride_k;
|
||||
|
||||
const void* v_ptr;
|
||||
index_t stride_v;
|
||||
index_t nhead_stride_v;
|
||||
index_t batch_stride_v;
|
||||
|
||||
void* o_ptr;
|
||||
index_t stride_o;
|
||||
index_t nhead_stride_o;
|
||||
index_t batch_stride_o;
|
||||
|
||||
// Optional batch-mode cumulative seqlen overrides (exclude PAD)
|
||||
// If provided, they override per-batch effective lengths to skip tail padding.
|
||||
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
|
||||
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,179 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
|
||||
const fmha_fwd_v3_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair(true, \
|
||||
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType>
|
||||
struct fmha_fwd_v3_problem_traits;
|
||||
|
||||
template <>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::half_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::half_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::bf16_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::bf16_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
|
||||
struct fmha_fwd_v3_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// M0 N0 K0 N1 K1
|
||||
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
|
||||
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using fmha_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using fmha_shape = TileFmhaShape<fmha_block_tile,
|
||||
fmha_block_warps,
|
||||
fmha_warp_gemm_shape,
|
||||
fmha_block_warps,
|
||||
fmha_warp_gemm_shape,
|
||||
true // IsVLayoutRowMajor
|
||||
>;
|
||||
|
||||
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
|
||||
true, // kPadSeqLenK
|
||||
false, // kPadHeadDimQ
|
||||
false, // kPadHeadDimV
|
||||
false, // kStoreLSE
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using fmha_pipeline_problem =
|
||||
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
|
||||
fmha_shape,
|
||||
IsVariableSeqlen,
|
||||
fmha_mask,
|
||||
fmha_traits>;
|
||||
|
||||
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
|
||||
true, // kPadM
|
||||
true, // kPadM
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
|
||||
/// maximizes the kernel's performance.
|
||||
int remap_opt = 2;
|
||||
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
|
||||
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
|
||||
{
|
||||
if(65536 <= args.seqlen_q)
|
||||
{
|
||||
remap_opt = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
remap_opt = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_qk,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_kv,
|
||||
args.softmax_scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
0, // batch_stride_lse
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
remap_opt,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_kv_ptr);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
template <typename KernelTraits>
|
||||
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -284,26 +284,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else if(init == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
|
||||
topk_weight_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed}(d_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed}(topk_weight_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
|
||||
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
|
||||
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
|
||||
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
|
||||
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
|
||||
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
|
||||
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
|
||||
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed}(a_host);
|
||||
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed}(g_host);
|
||||
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed}(d_host);
|
||||
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed}(sa_host);
|
||||
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed}(sg_host);
|
||||
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed}(sd_host);
|
||||
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed}(sy_host);
|
||||
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed}(topk_weight_host);
|
||||
}
|
||||
|
||||
// permute weight
|
||||
|
||||
@@ -9,14 +9,190 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
GemmConfig::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
GemmConfig::Persistent>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
ADataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
@@ -59,41 +235,48 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>; // Persistence
|
||||
GemmConfig::Persistent>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = typename std::conditional<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>,
|
||||
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped;
|
||||
|
||||
using QuantGemmProblem = std::conditional_t<
|
||||
UseGroupedQuant,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize,
|
||||
GemmConfig::TransposeC>,
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
QuantGroupSize>>,
|
||||
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
GemmConfig::TransposeC,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
scheduler>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
using GemmPipeline =
|
||||
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
|
||||
GemmConfig::PreshuffleB>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -146,6 +329,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
int result1 = run_grouped_gemm_example(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <bool Persistent_>
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -83,10 +84,11 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
template <typename PrecType, bool Persistent>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::QuantType QuantMode>
|
||||
struct GemmQuantConfig;
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename PrecType, bool Persistent>
|
||||
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using GemmPipeline = std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
|
||||
template <typename GemmProblem, bool PreshuffleB = false>
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
@@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
|
||||
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
assert(args[0].k_batch == 1);
|
||||
for(const auto& arg : args)
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
ave_time =
|
||||
grouped_gemm<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
if(args[0].k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("Split-K not supported yet for persistent kernel");
|
||||
}
|
||||
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
QuantMode>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
|
||||
|
||||
@@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
|
||||
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
|
||||
BQK = 0; // No B quantization
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
AQK = 0; // No A quantization
|
||||
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
|
||||
if(K % QuantGroupSize::kK != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
|
||||
throw std::runtime_error(
|
||||
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] =
|
||||
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = 0; // No B quantization
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
@@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
}
|
||||
@@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_tensors[i],
|
||||
c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
@@ -452,6 +503,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(
|
||||
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
@@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
@@ -511,7 +574,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
template <typename PrecType, ck_tile::QuantType QuantMode>
|
||||
int run_gemm_example_persistency(
|
||||
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
|
||||
{
|
||||
if(persistent)
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
|
||||
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
std::string quant_mode = arg_parser.get_str("quant_mode");
|
||||
bool persistent = arg_parser.get_bool("persistent");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
if(quant_mode == "tensor")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "rowcol")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "aquant")
|
||||
{
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else if(quant_mode == "bquant")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
|
||||
a_layout, b_layout, persistent, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -71,17 +71,17 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
|
||||
|
||||
const ck_tile::index_t K_split =
|
||||
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
|
||||
@@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
@@ -433,7 +440,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQ = 0; // No A quantization
|
||||
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout));
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
|
||||
@@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileThreadBlockDescriptor = requires(T t) {
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileTransferDescriptor = requires(T t) {
|
||||
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept TileBlockGemmDescriptor = requires(T t) {
|
||||
{ t.warps.m } -> std::convertible_to<int>;
|
||||
{ t.warps.n } -> std::convertible_to<int>;
|
||||
{ t.warps.k } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ t.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ t.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ t.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies optimizations (CK Tile).
|
||||
template <typename T>
|
||||
concept TileOptimizationsDescriptor = requires(T t) {
|
||||
{ t.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ t.split_image } -> std::convertible_to<bool>;
|
||||
{ t.explicit_gemm } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
|
||||
// concept.
|
||||
template <typename T>
|
||||
@@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies thread block info (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileThreadBlock = requires {
|
||||
{ T::thread_block } -> TileThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
@@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransfer = requires(T t) {
|
||||
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
@@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
concept SpecifiesTileBlockGemm = requires {
|
||||
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
|
||||
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
template <typename T>
|
||||
concept SpecifiesTileOptimizations = requires {
|
||||
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
|
||||
{ T::optimizations.split_image } -> std::convertible_to<bool>;
|
||||
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileConvSpecialization = requires {
|
||||
{ T::specialization } -> std::convertible_to<TileConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConvSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for input and output vector transfer (CK Tile).
|
||||
template <auto Value>
|
||||
concept TileInputOutputVectorTransferLimits =
|
||||
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
|
||||
@@ -59,6 +59,7 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -81,6 +82,15 @@ namespace ck_tile::builder::factory {
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
consteval bool IsTileAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
|
||||
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
|
||||
SpecifiesTileOptimizations<T>;
|
||||
}
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
consteval bool IsXdlV3Algorithm()
|
||||
@@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
}
|
||||
|
||||
@@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
|
||||
SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
@@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm()
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
}
|
||||
|
||||
@@ -120,7 +130,7 @@ template <typename T>
|
||||
consteval bool IsDlAlgorithm()
|
||||
{
|
||||
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
|
||||
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
}
|
||||
@@ -137,10 +147,15 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
StringLiteral VERSION>
|
||||
constexpr auto make_conv_instance()
|
||||
{
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// CK Tile supports common factory for each direction
|
||||
if constexpr(IsTileAlgorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>())
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for CK Tile Grouped Convolution kernels.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
struct ConvTileFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::TileElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
|
||||
static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations<ALGORITHM>();
|
||||
static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection<SIGNATURE>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
|
||||
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
|
||||
CONV_SPECIALIZATION,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
SCALAR_PER_VECTOR.a,
|
||||
SCALAR_PER_VECTOR.b,
|
||||
SCALAR_PER_VECTOR.c,
|
||||
OPTIMIZATIONS.num_groups_to_merge,
|
||||
OPTIMIZATIONS.split_image,
|
||||
OPTIMIZATIONS.explicit_gemm>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
|
||||
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
BLOCK_GEMM.double_smem_buffer,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::AsLayout,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::BsLayout,
|
||||
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::CLayout,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
BLOCK_GEMM.num_wave_groups>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
BLOCK_GEMM.scheduler,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Types::EDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename internal::TilePipelineType<
|
||||
BLOCK_GEMM.pipeline_version>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::AccDataType,
|
||||
typename Types::EDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK_GEMM.warps.m,
|
||||
BLOCK_GEMM.warps.n,
|
||||
BLOCK_GEMM.warp_tile.m,
|
||||
BLOCK_GEMM.warp_tile.n,
|
||||
BLOCK_GEMM.warp_tile.k,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
// TODO:: This template parameter will be moved inside the kernel
|
||||
ck_tile::memory_operation_enum::set,
|
||||
BLOCK_GEMM.num_wave_groups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
SCALAR_PER_VECTOR.c>>;
|
||||
|
||||
using Instance = typename internal::GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>::Instance;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
struct TileScalarPerVector
|
||||
{
|
||||
size_t a = 0;
|
||||
size_t b = 0;
|
||||
size_t c = 0;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileScalarPerVector SetTileBlockTransfer()
|
||||
{
|
||||
return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector,
|
||||
.b = ALGORITHM.transfer.b_scalar_per_vector,
|
||||
.c = ALGORITHM.transfer.c_scalar_per_vector};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ElementwiseOperation Op>
|
||||
struct ElementwiseOpToCKTile
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
|
||||
"Unsupported elementwise operation conversion to CK.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using Op = ck_tile::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using Op = ck_tile::element_wise::Scale;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOpToCKTile<ElementwiseOperation::CLAMP>
|
||||
{
|
||||
using Op = ck_tile::element_wise::Clamp;
|
||||
};
|
||||
|
||||
template <auto TensorDesc>
|
||||
consteval auto GetTileElementwiseOp()
|
||||
{
|
||||
if constexpr(HasTensorOp<decltype(TensorDesc)>)
|
||||
{
|
||||
constexpr auto op = TensorDesc.operation.elementwise_operation;
|
||||
return ElementwiseOpToCKTile<op>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
struct TileElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetTileElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetTileElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetTileElementwiseOp<Sig.output>();
|
||||
using AElementwiseOp = typename decltype(input_op)::Op;
|
||||
using BElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using CDEElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
struct GroupedConvolutionTileKernel
|
||||
{
|
||||
static_assert(false, "Unknown Direction");
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsBackwardData<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
typename GroupedConvTraitsType,
|
||||
typename TilePartitioner,
|
||||
typename GemmPipeline,
|
||||
typename ConvEpilogue>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct GroupedConvolutionTileKernel<SIGNATURE,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>
|
||||
{
|
||||
using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE>
|
||||
consteval ck_tile::GroupedConvDirection SetTileConvDirection()
|
||||
{
|
||||
constexpr auto direction = SIGNATURE.direction;
|
||||
using ck_tile_direction = ck_tile::GroupedConvDirection;
|
||||
switch(direction)
|
||||
{
|
||||
case ConvDirection::FORWARD: return ck_tile_direction::FORWARD;
|
||||
case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA;
|
||||
case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT;
|
||||
default: throw "Unknown Direction";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,200 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
|
||||
template <TensorLayout Layout>
|
||||
struct LayoutToCKTile
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
|
||||
"Unsupported layout conversion to CK.");
|
||||
};
|
||||
|
||||
// Bias layouts
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::G_K_strided>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::G_K;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::G_C_strided>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::G_C;
|
||||
};
|
||||
|
||||
// Input 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNWC;
|
||||
};
|
||||
|
||||
// Input 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NHWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNHWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNHWC;
|
||||
};
|
||||
|
||||
// Input 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NDHWGC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNDHWC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNDHWC;
|
||||
};
|
||||
|
||||
// Weight 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKXC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCX;
|
||||
};
|
||||
|
||||
// Weight 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKYXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCYX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCYX;
|
||||
};
|
||||
|
||||
// Weight 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKCZYX>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKCZYX;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GKZYXC>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
};
|
||||
|
||||
// Output 1D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
// Output 2D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NHWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNHWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
// Output 3D
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::NDHWGK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
template <>
|
||||
struct LayoutToCKTile<TensorLayout::GNDHWK>
|
||||
{
|
||||
using type = ck_tile::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <TensorLayout Layout>
|
||||
consteval auto TensorLayoutToCKTile()
|
||||
{
|
||||
return typename LayoutToCKTile<Layout>::type{};
|
||||
}
|
||||
|
||||
struct EmptyAuxiliaryTileTensorLayout
|
||||
{
|
||||
using type = ck_tile::tuple<>;
|
||||
};
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
|
||||
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
{
|
||||
return ck_tile::tuple<
|
||||
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct AuxiliaryTileTensorLayouts
|
||||
{
|
||||
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
|
||||
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
|
||||
std::make_index_sequence<Size>{}));
|
||||
};
|
||||
|
||||
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTileTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
struct TileConvTensorLayouts
|
||||
{
|
||||
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Type mappings from builder convolution data type to CK Tile tensor types.
|
||||
template <DataType T>
|
||||
struct TileConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using AComputeType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using BComputeType = ck_tile::half_t;
|
||||
using CShuffleDataType = ck_tile::half_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using AComputeType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using BComputeType = ck_tile::bf16_t;
|
||||
using CShuffleDataType = ck_tile::bf16_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using AComputeType = float;
|
||||
using BDataType = float;
|
||||
using BComputeType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::I8>
|
||||
{
|
||||
using ADataType = int8_t;
|
||||
using AComputeType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using BComputeType = int8_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = int32_t;
|
||||
using EDataType = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TileConvTensorTypes<DataType::FP8>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using AComputeType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using BComputeType = ck_tile::fp8_t;
|
||||
using CShuffleDataType = ck_tile::fp8_t;
|
||||
using DsDataTypes = ck_tile::tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
struct TileBlockMNK
|
||||
{
|
||||
int m{};
|
||||
int n{};
|
||||
int k{};
|
||||
};
|
||||
|
||||
struct TileConvBlock
|
||||
{
|
||||
TileBlockMNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr TileConvBlock SetTileThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return TileConvBlock{
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -0,0 +1,158 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
struct TileBlockGemmMNK
|
||||
{
|
||||
int m{};
|
||||
int n{};
|
||||
int k{};
|
||||
};
|
||||
|
||||
struct TileBlockGemmSpec
|
||||
{
|
||||
TileBlockGemmMNK warps = {};
|
||||
TileBlockGemmMNK warp_tile = {};
|
||||
|
||||
bool double_smem_buffer = false;
|
||||
int num_wave_groups = 1;
|
||||
|
||||
ck_tile::GemmPipeline pipeline_version;
|
||||
ck_tile::GemmPipelineScheduler scheduler;
|
||||
};
|
||||
|
||||
struct TileOptimizations
|
||||
{
|
||||
int num_groups_to_merge = 1;
|
||||
bool split_image = false;
|
||||
bool explicit_gemm = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipelineScheduler SetTileScheduler()
|
||||
{
|
||||
constexpr auto scheduler = ALGORITHM.block_gemm.scheduler;
|
||||
using ck_tile_sched = ck_tile::GemmPipelineScheduler;
|
||||
switch(scheduler)
|
||||
{
|
||||
case PipelineScheduler::DEFAULT: return ck_tile_sched::Default;
|
||||
case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave;
|
||||
case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave;
|
||||
default: throw "Unknown PipelineScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct TilePipelineType
|
||||
{
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_V1>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.block_gemm.pipeline_version;
|
||||
using ck_tile_pipeline = ck_tile::GemmPipeline;
|
||||
switch(version)
|
||||
{
|
||||
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
|
||||
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
|
||||
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.specialization;
|
||||
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
|
||||
case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
|
||||
case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0:
|
||||
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
|
||||
case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval TileBlockGemmSpec SetTileBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
constexpr bool double_smem_buffer = BG.double_smem_buffer;
|
||||
constexpr int num_wave_groups = BG.num_wave_groups;
|
||||
|
||||
constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion<ALGORITHM>();
|
||||
constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler<ALGORITHM>();
|
||||
|
||||
return TileBlockGemmSpec{
|
||||
.warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k},
|
||||
.warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k},
|
||||
.double_smem_buffer = double_smem_buffer,
|
||||
.num_wave_groups = num_wave_groups,
|
||||
.pipeline_version = pipeline_version,
|
||||
.scheduler = scheduler};
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval TileOptimizations SetTileOptimizations()
|
||||
{
|
||||
constexpr auto& OPT = ALGORITHM.optimizations;
|
||||
|
||||
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
|
||||
.split_image = OPT.split_image,
|
||||
.explicit_gemm = OPT.explicit_gemm};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -251,14 +251,10 @@ class ConvDescription : public Description
|
||||
};
|
||||
} // namespace conv
|
||||
|
||||
/// @brief Helper concept to detect if a type has ConvTraits specialization
|
||||
template <typename T>
|
||||
concept HasConvTraits = requires { typename conv::ConvTraits<T>; };
|
||||
|
||||
/// @brief Factory function to create ConvDescription from a convolution instance type
|
||||
/// @tparam Instance The convolution instance type (must have InstanceTraits specialization)
|
||||
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
|
||||
/// @return A ConvDescription object populated with the instance's configuration details
|
||||
template <HasConvTraits Instance>
|
||||
template <conv::HasConvTraits Instance>
|
||||
conv::ConvDescription describe()
|
||||
{
|
||||
using Traits = conv::ConvTraits<Instance>;
|
||||
|
||||
@@ -4,23 +4,74 @@
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck_tile/builder/types.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
#include <ck_tile/ops/gemm.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include <ck_tile/ops/grouped_convolution.hpp>
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Forward convolution layout concept - checks for A/B/E layout types
|
||||
template <typename T>
|
||||
concept HasFwdConvLayouts = requires {
|
||||
typename T::ALayout;
|
||||
typename T::BLayout;
|
||||
typename T::ELayout;
|
||||
};
|
||||
|
||||
// GEMM specialization concept - checks for kGemmSpecialization member
|
||||
template <typename T>
|
||||
concept HasGemmSpec = requires {
|
||||
{
|
||||
T::kGemmSpecialization
|
||||
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
|
||||
};
|
||||
|
||||
// Data types concept - checks for ADataType member
|
||||
template <typename T>
|
||||
concept HasDataTypes = requires { typename T::ADataType; };
|
||||
|
||||
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
|
||||
template <typename T>
|
||||
concept HasElementwiseOps = requires {
|
||||
typename T::AElementwiseOperation;
|
||||
typename T::BElementwiseOperation;
|
||||
typename T::CDEElementwiseOperation;
|
||||
};
|
||||
|
||||
// Tile parameters concept - checks for tile dimension and transfer members
|
||||
template <typename T>
|
||||
concept HasTileParams = requires {
|
||||
{ T::kKPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kMPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kNPerBlock } -> std::convertible_to<int>;
|
||||
{ T::kAK1 } -> std::convertible_to<int>;
|
||||
{ T::kBK1 } -> std::convertible_to<int>;
|
||||
T::kCThreadClusterLengths;
|
||||
};
|
||||
|
||||
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
|
||||
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
|
||||
template <typename T>
|
||||
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
|
||||
HasElementwiseOps<T> && HasTileParams<T>;
|
||||
|
||||
// Primary concept for checking if a type can be described
|
||||
// Currently only forward convolutions are supported, but this can be extended
|
||||
// in the future to include backward data and backward weight convolutions
|
||||
template <typename T>
|
||||
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v3)
|
||||
return V3;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == v5)
|
||||
return V5;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v3: return V3;
|
||||
case v4: return V4;
|
||||
case v5: return V5;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
@@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == weight_only)
|
||||
return WEIGHT_ONLY;
|
||||
|
||||
switch(ck_ver)
|
||||
{
|
||||
case v1: return V1;
|
||||
case v2: return V2;
|
||||
case v4: return V4;
|
||||
case weight_only: return WEIGHT_ONLY;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Intrawave)
|
||||
return INTRAWAVE;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Intrawave: return INTRAWAVE;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Default)
|
||||
return DEFAULT;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
|
||||
switch(ck_sched)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
case Interwave: return INTERWAVE;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
@@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction()
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::FORWARD;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
@@ -242,60 +288,52 @@ constexpr auto conv_spec()
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvFwdSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvForwardSpecialization == Default)
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_3x3;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter3x3: return FILTER_3x3;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvBwdDataSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvBwdWeightSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::ODD_C;
|
||||
case Default: return DEFAULT;
|
||||
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
|
||||
case Filter1x1Pad0: return FILTER_1X1_PAD0;
|
||||
case OddC: return ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper variable template to check if CK layout enums match
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename E,
|
||||
typename ExpectedA,
|
||||
typename ExpectedB,
|
||||
typename ExpectedE>
|
||||
inline constexpr bool layouts_are =
|
||||
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return An std::array corresponding to the tensor layouts:
|
||||
@@ -304,112 +342,49 @@ constexpr auto conv_spec()
|
||||
/// index 2 -> Output layout
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
requires HasFwdConvLayouts<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ALayout = typename InstTraits::ALayout;
|
||||
using BLayout = typename InstTraits::BLayout;
|
||||
using ELayout = typename InstTraits::ELayout;
|
||||
// Helper lambda to construct layout array
|
||||
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
using A = typename InstanceTraits<Instance>::ALayout;
|
||||
using B = typename InstanceTraits<Instance>::BLayout;
|
||||
using E = typename InstanceTraits<Instance>::ELayout;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using enum builder::TensorLayout;
|
||||
|
||||
if constexpr(InstTraits::kSpatialDim == 1)
|
||||
switch(InstanceTraits<Instance>::kSpatialDim)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::GNWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKXC,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
|
||||
builder::TensorLayout::GKCX,
|
||||
builder::TensorLayout::NGKW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::GNHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKYXC,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
|
||||
builder::TensorLayout::GKCYX,
|
||||
builder::TensorLayout::NGKHW};
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 3)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNDHWK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::GNDHWK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NDHWGK};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKZYXC,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCZYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
|
||||
builder::TensorLayout::GKCZYX,
|
||||
builder::TensorLayout::NGKDHW};
|
||||
}
|
||||
case 1:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
|
||||
return layouts(GNWC, GKXC, GNWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
|
||||
return layouts(NWGC, GKXC, NWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
|
||||
return layouts(NGCW, GKXC, NGKW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
|
||||
return layouts(NGCW, GKCX, NGKW);
|
||||
break;
|
||||
case 2:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
|
||||
return layouts(GNHWC, GKYXC, GNHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
|
||||
return layouts(NHWGC, GKYXC, NHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKYXC, NGKHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
|
||||
return layouts(NGCHW, GKCYX, NGKHW);
|
||||
break;
|
||||
case 3:
|
||||
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
|
||||
return layouts(GNDHWC, GKZYXC, GNDHWK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
|
||||
return layouts(NDHWGC, GKZYXC, NDHWGK);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKZYXC, NGKDHW);
|
||||
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
|
||||
return layouts(NGCDHW, GKCZYX, NGKDHW);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,39 +393,26 @@ constexpr auto conv_layout()
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
requires HasDataTypes<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
using enum builder::DataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
{
|
||||
return builder::DataType::FP16;
|
||||
}
|
||||
return FP16;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
{
|
||||
return builder::DataType::BF16;
|
||||
}
|
||||
return BF16;
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
{
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32;
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
{
|
||||
return builder::DataType::FP8;
|
||||
}
|
||||
return FP8;
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
{
|
||||
return builder::DataType::I8;
|
||||
}
|
||||
return I8;
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
{
|
||||
return builder::DataType::U8;
|
||||
}
|
||||
return U8;
|
||||
else
|
||||
{
|
||||
// Default fallback
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
return FP32; // Default fallback
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
@@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type()
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
using enum builder::ElementwiseOperation;
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
|
||||
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALE;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
{
|
||||
return builder::ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
return BIAS_BNORM_CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
return CLAMP;
|
||||
if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
return SCALE;
|
||||
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
return PASS_THROUGH;
|
||||
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
|
||||
return SCALEADD_SCALEADD_RELU;
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
@@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op()
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
requires HasGemmSpec<InstanceTraits<Instance>>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
@@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec()
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return DEFAULT;
|
||||
}
|
||||
else if constexpr(gemm_spec == MPadding)
|
||||
{
|
||||
return M_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NPadding)
|
||||
{
|
||||
return N_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KPadding)
|
||||
{
|
||||
return K_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNPadding)
|
||||
{
|
||||
return MN_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKPadding)
|
||||
{
|
||||
return MK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKPadding)
|
||||
{
|
||||
return NK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKPadding)
|
||||
{
|
||||
return MNK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == OPadding)
|
||||
{
|
||||
return O_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MOPadding)
|
||||
{
|
||||
return MO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NOPadding)
|
||||
{
|
||||
return NO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KOPadding)
|
||||
{
|
||||
return KO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNOPadding)
|
||||
{
|
||||
return MNO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKOPadding)
|
||||
{
|
||||
return MKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKOPadding)
|
||||
{
|
||||
return NKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKOPadding)
|
||||
{
|
||||
return MNKO_PADDING;
|
||||
case Default: return DEFAULT;
|
||||
case MPadding: return M_PADDING;
|
||||
case NPadding: return N_PADDING;
|
||||
case KPadding: return K_PADDING;
|
||||
case MNPadding: return MN_PADDING;
|
||||
case MKPadding: return MK_PADDING;
|
||||
case NKPadding: return NK_PADDING;
|
||||
case MNKPadding: return MNK_PADDING;
|
||||
case OPadding: return O_PADDING;
|
||||
case MOPadding: return MO_PADDING;
|
||||
case NOPadding: return NO_PADDING;
|
||||
case KOPadding: return KO_PADDING;
|
||||
case MNOPadding: return MNO_PADDING;
|
||||
case MKOPadding: return MKO_PADDING;
|
||||
case NKOPadding: return NKO_PADDING;
|
||||
case MNKOPadding: return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -571,6 +481,7 @@ struct ConvTraits;
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <HasInstanceTraits Instance>
|
||||
requires IsXdlFwdConv<InstanceTraits<Instance>>
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
@@ -8,28 +8,30 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <limits.h>
|
||||
#include <cmath>
|
||||
#include <ostream>
|
||||
#include <concepts>
|
||||
#include <iostream>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck_tile/ops/gemm.hpp>
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include <limits.h>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
|
||||
@@ -145,6 +145,15 @@ enum struct GemmSpecialization
|
||||
MNKOPadding
|
||||
};
|
||||
|
||||
// Enums for the CK Tile convolution specialization.
|
||||
enum class TileConvSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
{
|
||||
|
||||
@@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder
|
||||
|
||||
# Tests convolution trait selection and configuration
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
conv/ck/test_conv_traits.cpp)
|
||||
|
||||
# Tests convolution problem description and parameter handling
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
@@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string
|
||||
# Tests the forward convolution builder across multiple data types and dimensions.
|
||||
# Individual tests are split into separate files to enable parallel compilation.
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_1d_i8.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp8.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_data",
|
||||
"fp16",
|
||||
"NHWGC_GKYXC_NHWGK",
|
||||
"64x64x64",
|
||||
"2x2",
|
||||
"16x16x16",
|
||||
// "4x4x4", // TODO: Enable this check
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"CShuffleEpilogue",
|
||||
"set",
|
||||
"pipeline_AgBgCrCompV3",
|
||||
"DoubleSmemBuffer_0",
|
||||
"NumWaveGroups_1",
|
||||
"MergedGroups_1",
|
||||
"SplitImage_0",
|
||||
"ExplicitGemm_0",
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
"NHWGC_GKYXC_NHWGK",
|
||||
"64x64x64",
|
||||
"2x2",
|
||||
"16x16x16",
|
||||
// "4x4x4", // TODO: Enable this check
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"CShuffleEpilogue",
|
||||
"set",
|
||||
"pipeline_AgBgCrCompV3",
|
||||
"DoubleSmemBuffer_0",
|
||||
"NumWaveGroups_1",
|
||||
"MergedGroups_1",
|
||||
"SplitImage_0",
|
||||
"ExplicitGemm_0",
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_forward",
|
||||
"fp16",
|
||||
"NHWGC_GKYXC_NHWGK",
|
||||
"64x64x64",
|
||||
"2x2",
|
||||
"16x16x16",
|
||||
// "4x4x4", // TODO: Enable this check
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"CShuffleEpilogue",
|
||||
"set",
|
||||
"pipeline_AgBgCrCompV3",
|
||||
"DoubleSmemBuffer_0",
|
||||
"NumWaveGroups_1",
|
||||
"MergedGroups_1",
|
||||
"SplitImage_0",
|
||||
"ExplicitGemm_0",
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -243,6 +243,73 @@ struct LargeTensorWrapper
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM (CK Tile).
|
||||
struct TileThreadBlock
|
||||
{
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
|
||||
|
||||
struct TileTransfer
|
||||
{
|
||||
size_t a_scalar_per_vector;
|
||||
size_t b_scalar_per_vector;
|
||||
size_t c_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
|
||||
|
||||
struct TileBlockGemm
|
||||
{
|
||||
// Number of warps per each dimension.
|
||||
MNK<int> warps;
|
||||
// Number of data processed per each dimension for each XDL/WMMA instruction.
|
||||
MNK<int> warp_tile;
|
||||
// Double LDS buffer.
|
||||
bool double_smem_buffer;
|
||||
// Waves grouping (Ping-Pong scheduler).
|
||||
int num_wave_groups;
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
|
||||
|
||||
struct TileOptimizations
|
||||
{
|
||||
// Number of convolution groups processed per one workgroup
|
||||
int num_groups_to_merge;
|
||||
// Split image for large tensors
|
||||
bool split_image;
|
||||
// Explicit gemm for 1x1, stride=0, pad=0 cases
|
||||
bool explicit_gemm;
|
||||
};
|
||||
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
|
||||
|
||||
struct TileConvSpecialization_
|
||||
{
|
||||
TileConvSpecialization specialization;
|
||||
};
|
||||
|
||||
struct TileThreadBlock_
|
||||
{
|
||||
TileThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct TileTransfer_
|
||||
{
|
||||
TileTransfer transfer;
|
||||
};
|
||||
|
||||
struct TileBlockGemm_
|
||||
{
|
||||
TileBlockGemm block_gemm;
|
||||
};
|
||||
|
||||
struct TileOptimizations_
|
||||
{
|
||||
TileOptimizations optimizations;
|
||||
};
|
||||
|
||||
// Factory
|
||||
|
||||
template <typename... Components>
|
||||
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
constexpr auto with_tile_specializations(const S& s) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.specialization = s;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename TB>
|
||||
constexpr auto with_tile_thread_block(const TB& tb) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.thread_block = tb;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_tile_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr auto with_tile_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
constexpr auto with_tile_optimizations(const O& o) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TileOptimizations_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.optimizations = o;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
|
||||
TileBlockGemm_,
|
||||
TileTransfer_,
|
||||
TileConvSpecialization_,
|
||||
TileOptimizations_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/utility/reduction_operator.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/reduction_operator.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
namespace {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -28,4 +28,20 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
|
||||
}
|
||||
}
|
||||
|
||||
// Common CK Tile test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_ck_tile_test(const std::vector<std::string>& kernel_instance_components)
|
||||
{
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
std::cout << kernel_string << std::endl;
|
||||
for(const auto& component : kernel_instance_components)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_1x1x1{
|
||||
.a_scalar_per_vector = 1,
|
||||
.b_scalar_per_vector = 1,
|
||||
.c_scalar_per_vector = 1,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_4x4x4{
|
||||
.a_scalar_per_vector = 4,
|
||||
.b_scalar_per_vector = 4,
|
||||
.c_scalar_per_vector = 4,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_8x8x8{
|
||||
.a_scalar_per_vector = 8,
|
||||
.b_scalar_per_vector = 8,
|
||||
.c_scalar_per_vector = 8,
|
||||
};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
.warp_tile = {.m = 16, .n = 16, .k = 16},
|
||||
.double_smem_buffer = false,
|
||||
.num_wave_groups = 1,
|
||||
.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
.warp_tile = {.m = 16, .n = 16, .k = 16},
|
||||
.double_smem_buffer = false,
|
||||
.num_wave_groups = 1,
|
||||
.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
.warp_tile = {.m = 16, .n = 16, .k = 16},
|
||||
.double_smem_buffer = false,
|
||||
.num_wave_groups = 1,
|
||||
.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
.warp_tile = {.m = 16, .n = 16, .k = 16},
|
||||
.double_smem_buffer = false,
|
||||
.num_wave_groups = 1,
|
||||
.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
.warp_tile = {.m = 16, .n = 16, .k = 16},
|
||||
.double_smem_buffer = false,
|
||||
.num_wave_groups = 1,
|
||||
.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
@@ -55,6 +55,11 @@
|
||||
#ifndef CK_ENABLE_FP32
|
||||
#define CK_ENABLE_FP32 "ON"
|
||||
#endif
|
||||
#ifndef CK_ENABLE_TF32
|
||||
#if defined(__gfx942__) || defined(__gfx95__)
|
||||
#define CK_ENABLE_TF32 "ON"
|
||||
#endif
|
||||
#endif
|
||||
#ifndef CK_ENABLE_FP64
|
||||
#define CK_ENABLE_FP64 "ON"
|
||||
#endif
|
||||
@@ -85,6 +90,12 @@
|
||||
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_TF32
|
||||
#if defined(__gfx942__) || defined(__gfx95__)
|
||||
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_FP64
|
||||
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
|
||||
#endif
|
||||
|
||||
@@ -5,24 +5,16 @@
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/pipeline_enum.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct PipelineVersion
|
||||
{
|
||||
v1,
|
||||
v2,
|
||||
// v3 is only used in the Stream-K implementation.
|
||||
v4,
|
||||
weight_only,
|
||||
};
|
||||
|
||||
template <PipelineVersion PipelineVer,
|
||||
index_t NumPrefetch = 1,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default,
|
||||
@@ -62,18 +54,3 @@ constexpr auto GridwiseGemmPipeline_Selector()
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
{
|
||||
switch(p)
|
||||
{
|
||||
case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
|
||||
case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
|
||||
case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
|
||||
case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -3,52 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct BlockGemmPipelineVersion
|
||||
{
|
||||
// For GEMM
|
||||
v1, // Naive
|
||||
v2, // Mem
|
||||
v3, // Comp
|
||||
v4, // Comp, double lds buffer
|
||||
v5, // Comp, double global prefetch register buffer
|
||||
|
||||
// For GEMM with preshuffled weight
|
||||
// v1, single lds buffer
|
||||
// v2, double lds buffer
|
||||
};
|
||||
enum struct BlockGemmPipelineScheduler
|
||||
{
|
||||
Intrawave,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
enum struct TailNumber
|
||||
{
|
||||
// Single / Double buffer pipeline
|
||||
Odd,
|
||||
Even,
|
||||
|
||||
// Long prefetch pipeline, up to 8
|
||||
One,
|
||||
Two,
|
||||
Three,
|
||||
Four,
|
||||
Five,
|
||||
Six,
|
||||
Seven,
|
||||
|
||||
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
|
||||
Empty,
|
||||
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
|
||||
// prefetchstages
|
||||
Full,
|
||||
};
|
||||
|
||||
enum SchedulerGroup : uint32_t
|
||||
{
|
||||
SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions
|
||||
|
||||
@@ -3,40 +3,20 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct LoopScheduler
|
||||
{
|
||||
Default,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
/// @brief Helper function to get default loop scheduler
|
||||
/// @details Returns the default loop scheduler based on compile-time configuration.
|
||||
constexpr LoopScheduler make_default_loop_scheduler()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
|
||||
return LoopScheduler::Interwave;
|
||||
#else
|
||||
return LoopScheduler::Default;
|
||||
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck::LoopScheduler::Default: os << "Default"; break;
|
||||
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
40
include/ck/utility/pipeline_enum.hpp
Normal file
40
include/ck/utility/pipeline_enum.hpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
/// @brief Pipeline version enumeration for GEMM kernels
|
||||
/// @details Defines different pipeline strategies for data movement and computation overlap
|
||||
/// in GEMM kernels. This is a lightweight header containing only the enum definition,
|
||||
/// extracted from gridwise_gemm_pipeline_selector.hpp to minimize dependencies.
|
||||
enum struct PipelineVersion
|
||||
{
|
||||
v1, ///< Version 1 pipeline
|
||||
v2, ///< Version 2 pipeline
|
||||
// v3 is only used in the Stream-K implementation.
|
||||
v4, ///< Version 4 pipeline
|
||||
weight_only, ///< Weight-only specialized pipeline
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
{
|
||||
switch(p)
|
||||
{
|
||||
case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
|
||||
case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
|
||||
case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
|
||||
case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
83
include/ck/utility/scheduler_enum.hpp
Normal file
83
include/ck/utility/scheduler_enum.hpp
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
/// @brief Block GEMM pipeline version enumeration
|
||||
/// @details Defines different block GEMM pipeline strategies.
|
||||
/// This is a lightweight header containing only enum definitions,
|
||||
/// extracted from blkgemmpipe_scheduler.hpp to minimize dependencies.
|
||||
enum struct BlockGemmPipelineVersion
|
||||
{
|
||||
// For GEMM
|
||||
v1, ///< Naive pipeline
|
||||
v2, ///< Memory-optimized pipeline
|
||||
v3, ///< Compute-optimized pipeline
|
||||
v4, ///< Compute-optimized with double LDS buffer
|
||||
v5, ///< Compute-optimized with double global prefetch register buffer
|
||||
|
||||
// For GEMM with preshuffled weight
|
||||
// v1, single lds buffer
|
||||
// v2, double lds buffer
|
||||
};
|
||||
|
||||
/// @brief Block GEMM pipeline scheduler enumeration
|
||||
/// @details Defines scheduling strategies for block GEMM pipelines.
|
||||
enum struct BlockGemmPipelineScheduler
|
||||
{
|
||||
Intrawave, ///< Schedule within a single wavefront
|
||||
Interwave, ///< Schedule across multiple wavefronts
|
||||
};
|
||||
|
||||
/// @brief Loop scheduler enumeration
|
||||
/// @details Defines scheduling strategies for computational loops.
|
||||
enum struct LoopScheduler
|
||||
{
|
||||
Default, ///< Default scheduling strategy
|
||||
Interwave, ///< Cross-wavefront scheduling
|
||||
};
|
||||
|
||||
/// @brief Tail number enumeration for pipeline buffering
|
||||
/// @details Defines the number of tail iterations in pipelined loops.
|
||||
enum struct TailNumber
|
||||
{
|
||||
// Single / Double buffer pipeline
|
||||
Odd, ///< Odd number of iterations
|
||||
Even, ///< Even number of iterations
|
||||
|
||||
// Long prefetch pipeline, up to 8
|
||||
One, ///< One tail iteration
|
||||
Two, ///< Two tail iterations
|
||||
Three, ///< Three tail iterations
|
||||
Four, ///< Four tail iterations
|
||||
Five, ///< Five tail iterations
|
||||
Six, ///< Six tail iterations
|
||||
Seven, ///< Seven tail iterations
|
||||
|
||||
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
|
||||
Empty, ///< No tail iterations
|
||||
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
|
||||
// prefetchstages
|
||||
Full, ///< Full tail iterations
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck::LoopScheduler::Default: os << "Default"; break;
|
||||
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
@@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>&
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename Functor, typename LowLength>
|
||||
struct functor_transform : public base_transform<1, 1>
|
||||
{
|
||||
using LowerIndex = multi_index<1>;
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths = decltype(make_tuple(LowLength{}));
|
||||
|
||||
Functor functor_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr functor_transform() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor,
|
||||
const LowLength& low_length)
|
||||
: functor_{functor}, up_lengths_{make_tuple(low_length)}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = functor_(idx_up[number<0>{}]);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff&,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& up_idx) const
|
||||
{
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
const auto idx_low_old = idx_low;
|
||||
calculate_lower_index(idx_low, up_idx);
|
||||
idx_diff_low = idx_low - idx_low_old;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_always_mapped_to_valid_lower_index()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
// Note: When using functor_transform, ensure that the transformed coordinates
|
||||
// are always valid for vectorized load/store operations.
|
||||
template <typename LowVectorLengths, typename LowVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
|
||||
const LowVectorStrides& low_vector_strides)
|
||||
{
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
};
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
@@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
|
||||
return offset<LowLength, OffsetLength>{low_length, offset_length};
|
||||
}
|
||||
|
||||
template <typename Functor, typename LowLength>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor,
|
||||
const LowLength& low_length)
|
||||
{
|
||||
return functor_transform<Functor, LowLength>{functor, low_length};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
|
||||
@@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state
|
||||
#endif // __gfx950__
|
||||
|
||||
// GFX10
|
||||
#if defined(__gfx1010__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
#else
|
||||
@@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
|
||||
@@ -533,7 +533,8 @@ struct tile_scatter_gather
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
|
||||
@@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename = std::enable_if_t<is_tensor_view_v<TensorView_>>>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
@@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
|
||||
tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename = std::enable_if_t<is_tile_distribution_v<StaticTileDistribution>>>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
|
||||
@@ -517,7 +517,8 @@ struct tile_window_linear
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
m0_set_with_memory(
|
||||
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
|
||||
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
|
||||
@@ -33,59 +33,73 @@ namespace ck_tile {
|
||||
* @example
|
||||
*
|
||||
* // Direct usage without creating a separate variable:
|
||||
* ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host_tensor);
|
||||
* ck_tile::FillUniformDistribution<>{-1.f, 1.f}(a_host_tensor);
|
||||
*/
|
||||
template <typename T>
|
||||
template <typename T = void>
|
||||
struct FillUniformDistribution
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed
|
||||
// across threads).
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
if(threaded)
|
||||
{
|
||||
uint32_t num_thread = std::thread::hardware_concurrency();
|
||||
auto total = static_cast<std::size_t>(std::distance(first, last));
|
||||
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
|
||||
if(first == last)
|
||||
return;
|
||||
using T_iter = std::decay_t<decltype(*first)>;
|
||||
static_assert(std::is_same_v<T, T_iter> || std::is_void_v<T>,
|
||||
"Iterator value type must match template type T");
|
||||
constexpr auto PackedSize = numeric_traits<T_iter>::PackedSize;
|
||||
const auto total = static_cast<size_t>(std::distance(first, last));
|
||||
const auto total_bytes = total * sizeof(T_iter);
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
|
||||
auto thread_f = [this, total, iw_begin, iw_end, &first] {
|
||||
if(iw_begin > total || iw_end > total)
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
if constexpr(numeric_traits<T>::PackedSize == 2)
|
||||
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
|
||||
else
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
};
|
||||
threads[it] = joinable_thread(thread_f);
|
||||
}
|
||||
}
|
||||
else
|
||||
// max 80 threads; at least 2MB per thread
|
||||
const size_t available_cpu_cores = get_available_cpu_cores();
|
||||
const size_t num_thread =
|
||||
min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
|
||||
constexpr size_t BLOCK_BYTES = 64;
|
||||
constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter);
|
||||
const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES);
|
||||
const size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread);
|
||||
|
||||
// use minstd_rand for better performance on discard()
|
||||
std::minstd_rand gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
|
||||
std::vector<joinable_thread> threads;
|
||||
threads.reserve(num_thread - 1); // last job run in the main thread
|
||||
for(int it = num_thread - 1; it >= 0; --it)
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first, last, [&dis, &gen]() {
|
||||
if constexpr(numeric_traits<T>::PackedSize == 2)
|
||||
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
|
||||
else
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
const size_t ib_begin = it * blocks_per_thread;
|
||||
const size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks);
|
||||
|
||||
auto job = [=]() {
|
||||
auto g_ = gen; // copy
|
||||
auto d_ = dis; // copy
|
||||
g_.discard(ib_begin * BLOCK_SIZE * PackedSize);
|
||||
auto t_fn = [&]() {
|
||||
if constexpr(PackedSize == 2)
|
||||
return type_convert<T_iter>(fp32x2_t{d_(g_), d_(g_)});
|
||||
else
|
||||
return type_convert<T_iter>(d_(g_));
|
||||
};
|
||||
|
||||
size_t ib = ib_begin;
|
||||
for(; ib < ib_end - 1; ++ib) // full blocks
|
||||
static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) {
|
||||
constexpr size_t iw = iw_.value;
|
||||
*(first + ib * BLOCK_SIZE + iw) = t_fn();
|
||||
});
|
||||
for(size_t iw = 0; iw < BLOCK_SIZE; ++iw) // last block
|
||||
if(ib * BLOCK_SIZE + iw < total)
|
||||
*(first + ib * BLOCK_SIZE + iw) = t_fn();
|
||||
};
|
||||
|
||||
if(it > 0)
|
||||
threads.emplace_back(std::move(job));
|
||||
else
|
||||
job(); // last job run in the main thread
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __linux__
|
||||
#include <sched.h>
|
||||
#endif
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
@@ -24,4 +27,50 @@ struct joinable_thread : std::thread
|
||||
this->join();
|
||||
}
|
||||
};
|
||||
|
||||
inline unsigned int get_available_cpu_cores()
|
||||
{
|
||||
#if defined(__linux__)
|
||||
cpu_set_t cpu_set;
|
||||
if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0)
|
||||
{
|
||||
unsigned int cpu_count = CPU_COUNT(&cpu_set);
|
||||
if(cpu_count > 0)
|
||||
return cpu_count;
|
||||
}
|
||||
#endif
|
||||
// Fallback if sched_getaffinity unavailable or fails
|
||||
return std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
class cpu_core_guard
|
||||
{
|
||||
#if defined(__linux__)
|
||||
cpu_set_t original_cpu_set_;
|
||||
|
||||
public:
|
||||
cpu_core_guard(unsigned int num_cores) : original_cpu_set_()
|
||||
{
|
||||
// save original cpu set
|
||||
sched_getaffinity(0, sizeof(cpu_set_t), &original_cpu_set_);
|
||||
|
||||
// set new cpu set
|
||||
cpu_set_t new_cpu_set;
|
||||
CPU_ZERO(&new_cpu_set);
|
||||
for(unsigned int i = 0; i < num_cores; ++i)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
CPU_SET(i, &new_cpu_set); // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set);
|
||||
}
|
||||
~cpu_core_guard()
|
||||
{
|
||||
// restore original cpu set
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), &original_cpu_set_);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel
|
||||
auto fused_token =
|
||||
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
|
||||
|
||||
index_t scatter_token_id = fused_token & token_id_mask;
|
||||
index_t scatter_token_id = fused_token & token_id_mask;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
if constexpr(IsInputGemm)
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
{
|
||||
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using ALayout = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
|
||||
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<MXFlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
|
||||
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, MXFlatmmPipeline::kPadK>{});
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
sequence<false, MXFlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<FlatmmPipeline::kPadM, false>{});
|
||||
sequence<MXFlatmmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
|
||||
"A tensor for mx must be RowMajor");
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
|
||||
number<MXFlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
MXFlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
@@ -487,10 +461,10 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
|
||||
splitk_batch_offset.a_k_split_offset / APackedSize;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / BPackedSize;
|
||||
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
|
||||
splitk_batch_offset.a_k_split_offset / APackedSize;
|
||||
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / BPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user