rebase with develop

This commit is contained in:
khushbu agarwal
2025-12-09 19:12:29 -05:00
164 changed files with 4843 additions and 2775 deletions

3
.gitignore vendored
View File

@@ -36,6 +36,9 @@ tags
# Editors
.vscode
# CMake formatting configuration (local)
.cmake-format.yaml
# Cline
.cline*

View File

@@ -2,6 +2,16 @@
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## (Unreleased) Composable Kernel 1.3.0
### Added
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
### Changed
### Upcoming changes
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added

View File

@@ -92,6 +92,10 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
# definition will be added based on the GPU target in the following section
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -106,6 +110,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
@@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
set(CK_GFX950_SUPPORT "ON")
endif()
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
@@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
set(add_inst 1)
endif()

View File

@@ -1,7 +1,7 @@
FROM ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.0.1
ARG ROCMVERSION=7.1.1
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
@@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN set -xe && \
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \
apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
apt update && \
apt install python3-setuptools python3-wheel -y && \
apt install rocm-dev -y

View File

@@ -1,4 +1,4 @@
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1"
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""

View File

@@ -29,4 +29,4 @@ RUN groupadd -g 109 render && \
git sparse-checkout set projects/hipblaslt shared/origami && \
cd projects/hipblaslt && \
git show --oneline -s && \
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --logic-yaml-filter gfx950/*/* --architecture="gfx942;gfx950" -j 128 --skip_rocroller
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller

22
Jenkinsfile vendored
View File

@@ -288,7 +288,7 @@ def getBaseDockerImageName(){
}
else{
def ROCM_numeric = parseVersion("${params.ROCMVERSION}")
if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){
if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 2 ){
img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}"
}
else{
@@ -434,7 +434,7 @@ def buildDocker(install_prefix){
}
catch(Exception ex){
echo "Unable to locate image: ${image_name}. Building image now"
retimage = docker.build("${image_name}", dockerArgs + ' .')
retimage = docker.build("${image_name}", dockerArgs)
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
retimage.push()
}
@@ -834,12 +834,14 @@ def Build_CK(Map conf=[:]){
if (params.hipTensor_test && arch == "gfx90a" ){
// build and test hipTensor on gfx90a node
sh """#!/bin/bash
rm -rf "${params.hipTensor_branch}".zip
rm -rf hipTensor-"${params.hipTensor_branch}"
wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip
unzip -o "${params.hipTensor_branch}".zip
rm -rf rocm-libraries
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git
cd rocm-libraries
git sparse-checkout init --cone
git sparse-checkout set projects/hiptensor
git checkout "${params.hipTensor_branch}"
"""
dir("hipTensor-${params.hipTensor_branch}"){
dir("rocm-libraries/projects/hiptensor"){
sh """#!/bin/bash
mkdir -p build
ls -ltr
@@ -1095,7 +1097,7 @@ def run_pytorch_tests(Map conf=[:]){
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX101=false;BUILD_GFX908=false;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true;BUILD_PACKAGES=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true
@@ -1121,8 +1123,8 @@ pipeline {
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
string(
name: 'ROCMVERSION',
defaultValue: '7.0.1',
description: 'Specify which ROCM version to use: 7.0.1 (default).')
defaultValue: '7.1.1',
description: 'Specify which ROCM version to use: 7.1.1 (default).')
string(
name: 'COMPILER_VERSION',
defaultValue: '',

View File

@@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb
Additional cmake flags can be used to significantly speed-up the build:
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
other data types.

View File

@@ -27,6 +27,9 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -41,6 +44,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
@@ -67,6 +71,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances for this target")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")

View File

@@ -1,2 +1,2 @@
rocm-docs-core[api_reference]==1.20.1
rocm-docs-core[api_reference]==1.31.0
sphinxcontrib-bibtex==2.6.5

View File

@@ -237,7 +237,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core[api-reference]==1.20.1
rocm-docs-core[api-reference]==1.31.0
# via -r requirements.in
rpds-py==0.24.0
# via

View File

@@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
fmha_fwd_v3.cpp
${FMHA_FWD_V3_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-mllvm --amdgpu-disable-packed-fp32=1
)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
-DCK_TILE_DISABLE_PACKED_FP32=1
)
endif()
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -30,16 +30,24 @@ _MASK_MAP = {
}
def get_mask_map(mask: str):
if mask == "generic":
def get_mask_map(mask_impl: str):
if mask_impl == "generic":
return _MASK_MAP
elif mask == "simplified":
elif mask_impl == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_impl(mask: str) -> str:
return "simplified" if mask.startswith("s_") else "generic"
def get_mask_cpp_type(mask: str) -> str:
return get_mask_map(get_mask_impl(mask))[mask]
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
@@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
return None
def get_mask_cpp_check_expr(mask: str) -> str:
return get_mask_check_map(get_mask_impl(mask))[mask]
QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
@@ -122,6 +134,7 @@ PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
}
PIPELINE_ENUM_MAP = {
@@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = {
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
}
BOOL_MAP = {

File diff suppressed because it is too large Load Diff

View File

@@ -1,616 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "3328", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"0",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "0", "permute output")
.insert("causal", "0", "0: no mask, 1: causal mask")
.insert("v", "1", "0:no verify, 1:verify")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "30", "number of iterations to benchmark the kernel")
// Optional effective seqlen override (exclude PAD) for batch mode
.insert("q_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.")
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
}
enum class TensorLayout
{
bhsd,
bshd,
};
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
switch(layout)
{
case TensorLayout::bhsd: return stream << "bhsd";
case TensorLayout::bshd: return stream << "bshd";
default: return stream << "unknown";
}
}
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
batch = args.get_int("b");
seqlen_q = args.get_int("s");
seqlen_k = args.get_int("s_k");
if(seqlen_k < 0)
{
seqlen_k = seqlen_q;
}
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
softmax_scale = args.get_float("scale_s");
if(softmax_scale == .0f)
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
const auto is_causal = args.get_bool("causal");
if(is_causal)
{
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
}
else
{
mask = mask_info::decode("0", seqlen_q, seqlen_k);
}
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
q_eff_lens = args.get_int_vec("q_eff_lens");
kv_eff_lens = args.get_int_vec("kv_eff_lens");
}
std::vector<ck_tile::index_t> get_query_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
std::vector<ck_tile::index_t> get_key_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_value_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_output_shape() const
{
if(output_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
float softmax_scale;
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
std::vector<int> q_eff_lens;
std::vector<int> kv_eff_lens;
};
struct RunConfig
{
explicit RunConfig(const ck_tile::ArgParser& args)
{
seed = args.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
kernel_warmup = args.get_int("warmup");
kernel_repeat = args.get_int("repeat");
verify = args.get_bool("v");
}
std::optional<uint32_t> seed;
int kernel_warmup;
int kernel_repeat;
bool verify;
};
template <typename DataType>
auto generate_qkv(const Problem& problem,
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
-> std::tuple<ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>>
{
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
return std::make_tuple(q, k, v);
}
namespace host {
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
const int nr = nhead_q / nhead_kv;
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
// Ensure output buffer is zero-initialized so padded regions compare cleanly
o_buf.SetZero();
ck_tile::fmha_fwd_v3_args args{};
args.data_type = problem.data_type;
args.batch = problem.batch;
args.seqlen_q = problem.seqlen_q;
args.seqlen_k = problem.seqlen_k;
args.nhead_q = problem.nhead_q;
args.nhead_kv = problem.nhead_kv;
args.hdim_qk = problem.hdim;
args.hdim_v = problem.hdim;
args.softmax_scale = problem.softmax_scale;
args.window_size_left = problem.mask.left;
args.window_size_right = problem.mask.right;
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.q_ptr = q_buf.GetDeviceBuffer();
args.stride_q =
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_q =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_k =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_v =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o =
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
? problem.hdim
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
// Optional cumulative seqlen overrides (exclude PAD)
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
std::vector<ck_tile::index_t> eff;
if(!opt_vec.empty() && opt_vec[0] != -1)
{
eff.assign(opt_vec.begin(), opt_vec.end());
if(eff.size() < static_cast<size_t>(problem.batch))
{
eff.resize(problem.batch, eff.back());
}
}
else
{
eff.assign(problem.batch, fallback);
}
return eff;
};
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
// Calculate cumulative sums for kernel arguments if varlen is used
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
cum_vec.resize(per_batch_vec.size() + 1);
cum_vec[0] = 0;
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
};
if(has_varlen_q)
{
calculate_cumulative(eff_q_vec, cuq_cum);
}
if(has_varlen_k)
{
calculate_cumulative(eff_kv_vec, cukv_cum);
}
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
args.cu_seqlen_q_ptr =
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
: nullptr;
args.cu_seqlen_kv_ptr =
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
: nullptr;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
run_config.kernel_warmup,
run_config.kernel_repeat};
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
if(!run_config.verify)
{
return true;
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
// If variable lengths are provided, compute per-batch references
// with the effective lengths; else compute a single full reference.
if(has_varlen_q || has_varlen_k)
{
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
}
else
{
// No varlen override: compute the full reference once
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])
{
auto [parse_result, args] = parse_cmd_args(argc, argv);
if(!parse_result)
{
std::cerr << "failed to parse command line arguments" << std::endl;
}
Problem problem(args);
RunConfig run_config(args);
const auto run = [&] {
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
{
return run_impl<ck_tile::fp16_t>(problem, run_config);
}
else
{
return run_impl<ck_tile::bf16_t>(problem, run_config);
}
};
return !run();
}

View File

@@ -686,6 +686,100 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
}
template <typename FmhaKernel>
auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
{
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
/// maximizes the kernel's performance.
int remap_opt = 2;
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
{
if(65536 <= args.seqlen_q)
{
remap_opt = 0;
}
else
{
remap_opt = 1;
}
}
auto kargs = [&] {
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
remap_opt,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
0, // batch_stride_lse
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
remap_opt,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
}();
dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaKernel>
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
{

View File

@@ -1,60 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
{
switch(data_type)
{
case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16";
case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16";
default: return stream << "unknown";
}
}
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
{
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -1,73 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
struct fmha_fwd_v3_args
{
enum class data_type_enum
{
fp16,
bf16
};
data_type_enum data_type;
// bool is_varlen;
index_t batch;
index_t seqlen_q;
index_t seqlen_k;
index_t nhead_q;
index_t nhead_kv;
index_t hdim_qk;
index_t hdim_v;
float softmax_scale;
index_t window_size_left;
index_t window_size_right;
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
// window_size_right == 0).
const void* q_ptr;
index_t stride_q;
index_t nhead_stride_q;
index_t batch_stride_q;
const void* k_ptr;
index_t stride_k;
index_t nhead_stride_k;
index_t batch_stride_k;
const void* v_ptr;
index_t stride_v;
index_t nhead_stride_v;
index_t batch_stride_v;
void* o_ptr;
index_t stride_o;
index_t nhead_stride_o;
index_t batch_stride_o;
// Optional batch-mode cumulative seqlen overrides (exclude PAD)
// If provided, they override per-batch effective lengths to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -1,179 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
using o_dtype = ck_tile::bf16_t;
using lse_dtype = float;
};
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
static constexpr bool is_masking = IsMasking;
// M0 N0 K0 N1 K1
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
using fmha_block_warps = sequence<8, 1, 1>;
using fmha_shape = TileFmhaShape<fmha_block_tile,
fmha_block_warps,
fmha_warp_gemm_shape,
fmha_block_warps,
fmha_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
true, // kPadSeqLenK
false, // kPadHeadDimQ
false, // kPadHeadDimV
false, // kStoreLSE
-1 // kBlockPerCu
>;
using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
using fmha_pipeline_problem =
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
fmha_shape,
IsVariableSeqlen,
fmha_mask,
fmha_traits>;
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
true, // kPadM
true, // kPadM
true // UseRawStore
>>;
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
};
template <typename Kernel>
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
{
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
/// maximizes the kernel's performance.
int remap_opt = 2;
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
{
if(65536 <= args.seqlen_q)
{
remap_opt = 0;
}
else
{
remap_opt = 1;
}
}
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_qk,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_kv,
args.softmax_scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
0, // batch_stride_lse
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
remap_opt,
args.cu_seqlen_q_ptr,
args.cu_seqlen_kv_ptr);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -1,14 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -284,26 +284,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else if(init == 1)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
topk_weight_host);
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed}(topk_weight_host);
}
else if(init == 2)
{
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed}(topk_weight_host);
}
// permute weight

View File

@@ -9,14 +9,190 @@
#include <string>
#include <tuple>
#include <memory>
#include <type_traits>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
ck_tile::QuantType QuantMode = ck_tile::QuantType::BQuantGrouped>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false, // PreshuffleQuant
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
GemmConfig::Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
ADataType,
scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::TransposeC,
BDataType,
scheduler,
has_hot_loop_v,
tail_number_v>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
};
return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
@@ -59,41 +235,48 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQLayout,
GemmConfig::TransposeC,
GemmConfig::DoubleSmemBuffer,
true>; // Persistence
GemmConfig::Persistent>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = typename std::conditional<
QuantMode == ck_tile::QuantType::BQuantGrouped,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>,
constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped;
using QuantGemmProblem = std::conditional_t<
UseGroupedQuant,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize,
GemmConfig::TransposeC>,
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
QuantGroupSize>>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
GemmConfig::TransposeC,
BDataType,
scheduler>>::type;
scheduler>>;
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant,
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -146,6 +329,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
int main(int argc, char* argv[])
{
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
int result1 = run_grouped_gemm_example(argc, argv);
return result1;
}

View File

@@ -64,6 +64,7 @@ struct GemmTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <bool Persistent_>
struct GemmConfigBase
{
static constexpr bool kPadM = false;
@@ -83,10 +84,11 @@ struct GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = Persistent_;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigComputeV3_2 : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -101,8 +103,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
template <typename PrecType, bool Persistent>
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -121,6 +123,66 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
static constexpr bool DoubleSmemBuffer = true;
};
template <ck_tile::QuantType QuantMode>
struct GemmQuantConfig;
template <>
struct GemmQuantConfig<ck_tile::QuantType::TensorQuant>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::RowColQuant>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::AQuantGrouped>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigComputeV3_2<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<GemmProblem>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>;
};
template <>
struct GemmQuantConfig<ck_tile::QuantType::BQuantGrouped>
{
template <typename PrecType, bool Persistent>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<PrecType, Persistent>;
template <typename GemmProblem, bool PreshuffleB = false>
using GemmPipeline = std::conditional_t<PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<GemmProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<GemmProblem>>;
template <typename GemmProblem, bool PreshuffleB = false>
using BaseGemmPipeline =
std::conditional_t<PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmProblem>>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
@@ -148,8 +210,9 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
.insert("quant_mode", "bquant", "Choose aquant, bquant (default), tensor, or rowcol")
.insert("init", "0", "0. Random, 2. One(s) (Constant)")
.insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -57,56 +57,83 @@ float invoke_gemm(int n_warmup,
float ave_time = 0;
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
assert(args[0].k_batch == 1);
for(const auto& arg : args)
if constexpr(!GemmConfig::Persistent)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
ave_time =
grouped_gemm<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
if(args[0].k_batch != 1)
{
throw std::runtime_error("Split-K not supported yet for persistent kernel");
}
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
QuantGroupSize,
QuantMode>(stream, group_count, kargs_ptr);
std::string op_name = "Quant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")";
@@ -259,13 +286,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
BQK = 0; // No B quantization
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for AQuantGrouped mode");
}
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
AQK = 0; // No A quantization
BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error("K must be divisible by 128 for BQuantGrouped mode");
throw std::runtime_error(
"K must be divisible by QuantGroupSize::kK for BQuantGrouped mode");
}
}
@@ -284,6 +322,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
stride_AQs[i] =
ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQs[i] = 0; // No A quantization
@@ -311,10 +355,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, 1, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(0, 0, stride_BQs[i], is_row_major(bq_layout))));
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(0, AQK, stride_AQs[i], is_row_major(aq_layout))));
ck_tile::host_tensor_descriptor(0, 0, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout))));
}
@@ -444,7 +495,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
bq_tensors[i],
c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
@@ -452,6 +503,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
AccDataType,
CDataType,
QuantGroupSize,
true>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}
@@ -477,7 +539,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <typename GemmConfig, typename PrecType, ck_tile::QuantType QuantMode>
template <typename PrecType, ck_tile::QuantType QuantMode, typename GemmConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
@@ -494,6 +556,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
@@ -511,7 +574,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
}
template <template <typename PrecType> typename GemmConfig>
template <typename PrecType, ck_tile::QuantType QuantMode>
int run_gemm_example_persistency(
std::string a_layout, std::string b_layout, bool persistent, int argc, char* argv[])
{
if(persistent)
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, true>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
else
{
using GemmConfig = GemmQuantConfig<QuantMode>::template GemmConfig<PrecType, false>;
return run_gemm_example_prec_type<PrecType, QuantMode, GemmConfig>(
a_layout, b_layout, argc, argv);
}
}
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -524,29 +604,29 @@ int run_grouped_gemm_example(int argc, char* argv[])
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
std::string quant_mode = arg_parser.get_str("quant_mode");
bool persistent = arg_parser.get_bool("persistent");
if(data_type == "fp8")
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::fp8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{
@@ -557,24 +637,23 @@ int run_grouped_gemm_example(int argc, char* argv[])
{
if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "aquant")
{
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
return run_gemm_example_persistency<ck_tile::bf8_t, ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, persistent, argc, argv);
}
else
{

View File

@@ -71,17 +71,17 @@ int run_mx_flatmm_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
}
else
{

View File

@@ -69,7 +69,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>>>;
const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
@@ -128,7 +133,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleQuant == true,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
std::conditional_t<GemmConfig::PreshuffleB == true,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
@@ -433,7 +440,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
stride_AQ = 0; // No A quantization
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
stride_BQ = ck_tile::get_default_stride(BQK, BQN, stride_BQ, is_row_major(bq_layout));
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{

View File

@@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileTransferDescriptor = requires(T t) {
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept TileBlockGemmDescriptor = requires(T t) {
{ t.warps.m } -> std::convertible_to<int>;
{ t.warps.n } -> std::convertible_to<int>;
{ t.warps.k } -> std::convertible_to<int>;
{ t.warp_tile.m } -> std::convertible_to<int>;
{ t.warp_tile.n } -> std::convertible_to<int>;
{ t.warp_tile.k } -> std::convertible_to<int>;
{ t.double_smem_buffer } -> std::convertible_to<bool>;
{ t.num_wave_groups } -> std::convertible_to<int>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies optimizations (CK Tile).
template <typename T>
concept TileOptimizationsDescriptor = requires(T t) {
{ t.num_groups_to_merge } -> std::convertible_to<int>;
{ t.split_image } -> std::convertible_to<bool>;
{ t.explicit_gemm } -> std::convertible_to<bool>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
// concept.
template <typename T>
@@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if struct specifies thread block info (CK Tile).
template <typename T>
concept SpecifiesTileThreadBlock = requires {
{ T::thread_block } -> TileThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
@@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
template <typename T>
concept SpecifiesTileTransfer = requires(T t) {
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
@@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires {
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
concept SpecifiesTileBlockGemm = requires {
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept SpecifiesTileOptimizations = requires {
{ T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
{ T::optimizations.split_image } -> std::convertible_to<bool>;
{ T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
};
template <typename T>
concept SpecifiesTileConvSpecialization = requires {
{ T::specialization } -> std::convertible_to<TileConvSpecialization>;
};
template <typename T>
concept SpecifiesFwdConvSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};

View File

@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for input and output vector transfer (CK Tile).
template <auto Value>
concept TileInputOutputVectorTransferLimits =
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {

View File

@@ -59,6 +59,7 @@
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
namespace ck_tile::builder::factory {
@@ -81,6 +82,15 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// CK Tile kernel
template <typename T>
consteval bool IsTileAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
SpecifiesTileOptimizations<T>;
}
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
@@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesBlockGemm<T>;
}
@@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
}
@@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
}
@@ -120,7 +130,7 @@ template <typename T>
consteval bool IsDlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
}
@@ -137,10 +147,15 @@ template <ConvSignatureDescriptor auto SIGNATURE,
StringLiteral VERSION>
constexpr auto make_conv_instance()
{
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// CK Tile supports common factory for each direction
if constexpr(IsTileAlgorithm<AlgoType>())
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>())
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};

View File

@@ -7,11 +7,11 @@
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp"
namespace ck_tile::builder::factory {
// Factory for CK Tile Grouped Convolution kernels.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
struct ConvTileFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
using Ops = internal::TileElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations<ALGORITHM>();
static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();
static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection<SIGNATURE>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
CONV_SPECIALIZATION,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
SCALAR_PER_VECTOR.a,
SCALAR_PER_VECTOR.b,
SCALAR_PER_VECTOR.c,
OPTIMIZATIONS.num_groups_to_merge,
OPTIMIZATIONS.split_image,
OPTIMIZATIONS.explicit_gemm>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
BLOCK_GEMM.double_smem_buffer,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::AsLayout,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::BsLayout,
typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::CLayout,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
BLOCK_GEMM.num_wave_groups>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
GemmShape,
GemmUniversalTraits,
BLOCK_GEMM.scheduler,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Types::EDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename internal::TilePipelineType<
BLOCK_GEMM.pipeline_version>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::AccDataType,
typename Types::EDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
typename Ops::CDEElementwiseOp,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK_GEMM.warps.m,
BLOCK_GEMM.warps.n,
BLOCK_GEMM.warp_tile.m,
BLOCK_GEMM.warp_tile.n,
BLOCK_GEMM.warp_tile.k,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
// TODO:: This template parameter will be moved inside the kernel
ck_tile::memory_operation_enum::set,
BLOCK_GEMM.num_wave_groups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
SCALAR_PER_VECTOR.c>>;
using Instance = typename internal::GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>::Instance;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
struct TileScalarPerVector
{
size_t a = 0;
size_t b = 0;
size_t c = 0;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileScalarPerVector SetTileBlockTransfer()
{
return TileScalarPerVector{.a = ALGORITHM.transfer.a_scalar_per_vector,
.b = ALGORITHM.transfer.b_scalar_per_vector,
.c = ALGORITHM.transfer.c_scalar_per_vector};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
template <ElementwiseOperation Op>
struct ElementwiseOpToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
"Unsupported elementwise operation conversion to CK.");
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>
{
using Op = ck_tile::element_wise::PassThrough;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::SCALE>
{
using Op = ck_tile::element_wise::Scale;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::CLAMP>
{
using Op = ck_tile::element_wise::Clamp;
};
template <auto TensorDesc>
consteval auto GetTileElementwiseOp()
{
if constexpr(HasTensorOp<decltype(TensorDesc)>)
{
constexpr auto op = TensorDesc.operation.elementwise_operation;
return ElementwiseOpToCKTile<op>{};
}
else
{
return ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>{};
}
}
template <auto Sig>
struct TileElementwiseOps
{
static constexpr auto input_op = GetTileElementwiseOp<Sig.input>();
static constexpr auto weight_op = GetTileElementwiseOp<Sig.weight>();
static constexpr auto output_op = GetTileElementwiseOp<Sig.output>();
using AElementwiseOp = typename decltype(input_op)::Op;
using BElementwiseOp = typename decltype(weight_op)::Op;
using CDEElementwiseOp = typename decltype(output_op)::Op;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
struct GroupedConvolutionTileKernel
{
static_assert(false, "Unknown Direction");
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsForward<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsBackwardData<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE,
typename GroupedConvTraitsType,
typename TilePartitioner,
typename GemmPipeline,
typename ConvEpilogue>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>
{
using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
};
template <ConvSignatureDescriptor auto SIGNATURE>
consteval ck_tile::GroupedConvDirection SetTileConvDirection()
{
constexpr auto direction = SIGNATURE.direction;
using ck_tile_direction = ck_tile::GroupedConvDirection;
switch(direction)
{
case ConvDirection::FORWARD: return ck_tile_direction::FORWARD;
case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA;
case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT;
default: throw "Unknown Direction";
}
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,200 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
template <TensorLayout Layout>
struct LayoutToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
"Unsupported layout conversion to CK.");
};
// Bias layouts
template <>
struct LayoutToCKTile<TensorLayout::G_K_strided>
{
using type = ck_tile::tensor_layout::convolution::G_K;
};
template <>
struct LayoutToCKTile<TensorLayout::GC>
{
using type = ck_tile::tensor_layout::convolution::GC;
};
template <>
struct LayoutToCKTile<TensorLayout::G_C_strided>
{
using type = ck_tile::tensor_layout::convolution::G_C;
};
// Input 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGC>
{
using type = ck_tile::tensor_layout::convolution::NWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWC>
{
using type = ck_tile::tensor_layout::convolution::GNWC;
};
// Input 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGC>
{
using type = ck_tile::tensor_layout::convolution::NHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWC>
{
using type = ck_tile::tensor_layout::convolution::GNHWC;
};
// Input 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGC>
{
using type = ck_tile::tensor_layout::convolution::NDHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWC>
{
using type = ck_tile::tensor_layout::convolution::GNDHWC;
};
// Weight 1D
template <>
struct LayoutToCKTile<TensorLayout::GKXC>
{
using type = ck_tile::tensor_layout::convolution::GKXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCX>
{
using type = ck_tile::tensor_layout::convolution::GKCX;
};
// Weight 2D
template <>
struct LayoutToCKTile<TensorLayout::GKYXC>
{
using type = ck_tile::tensor_layout::convolution::GKYXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCYX>
{
using type = ck_tile::tensor_layout::convolution::GKCYX;
};
// Weight 3D
template <>
struct LayoutToCKTile<TensorLayout::GKCZYX>
{
using type = ck_tile::tensor_layout::convolution::GKCZYX;
};
template <>
struct LayoutToCKTile<TensorLayout::GKZYXC>
{
using type = ck_tile::tensor_layout::convolution::GKZYXC;
};
// Output 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGK>
{
using type = ck_tile::tensor_layout::convolution::NWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWK>
{
using type = ck_tile::tensor_layout::convolution::GNWK;
};
// Output 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGK>
{
using type = ck_tile::tensor_layout::convolution::NHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWK>
{
using type = ck_tile::tensor_layout::convolution::GNHWK;
};
// Output 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGK>
{
using type = ck_tile::tensor_layout::convolution::NDHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWK>
{
using type = ck_tile::tensor_layout::convolution::GNDHWK;
};
template <TensorLayout Layout>
consteval auto TensorLayoutToCKTile()
{
return typename LayoutToCKTile<Layout>::type{};
}
struct EmptyAuxiliaryTileTensorLayout
{
using type = ck_tile::tuple<>;
};
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck_tile::tuple<
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
}
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTileTensorLayouts
{
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
std::make_index_sequence<Size>{}));
};
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return EmptyAuxiliaryTileTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct TileConvTensorLayouts
{
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from builder convolution data type to CK Tile tensor types.
template <DataType T>
struct TileConvTensorTypes
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported data type for convolution factory.");
};
template <>
struct TileConvTensorTypes<DataType::FP16>
{
using ADataType = ck_tile::half_t;
using AComputeType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using BComputeType = ck_tile::half_t;
using CShuffleDataType = ck_tile::half_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::half_t;
};
template <>
struct TileConvTensorTypes<DataType::BF16>
{
using ADataType = ck_tile::bf16_t;
using AComputeType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using BComputeType = ck_tile::bf16_t;
using CShuffleDataType = ck_tile::bf16_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::bf16_t;
};
template <>
struct TileConvTensorTypes<DataType::FP32>
{
using ADataType = float;
using AComputeType = float;
using BDataType = float;
using BComputeType = float;
using CShuffleDataType = float;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = float;
};
template <>
struct TileConvTensorTypes<DataType::I8>
{
using ADataType = int8_t;
using AComputeType = int8_t;
using BDataType = int8_t;
using BComputeType = int8_t;
using CShuffleDataType = int8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = int32_t;
using EDataType = int8_t;
};
template <>
struct TileConvTensorTypes<DataType::FP8>
{
using ADataType = ck_tile::fp8_t;
using AComputeType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using BComputeType = ck_tile::fp8_t;
using CShuffleDataType = ck_tile::fp8_t;
using DsDataTypes = ck_tile::tuple<>;
using AccDataType = float;
using EDataType = ck_tile::fp8_t;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockMNK
{
int m{};
int n{};
int k{};
};
struct TileConvBlock
{
TileBlockMNK per_block = {};
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileConvBlock SetTileThreadBlockInfo()
{
constexpr auto& TB = ALGORITHM.thread_block;
return TileConvBlock{
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k},
};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,158 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockGemmMNK
{
int m{};
int n{};
int k{};
};
struct TileBlockGemmSpec
{
TileBlockGemmMNK warps = {};
TileBlockGemmMNK warp_tile = {};
bool double_smem_buffer = false;
int num_wave_groups = 1;
ck_tile::GemmPipeline pipeline_version;
ck_tile::GemmPipelineScheduler scheduler;
};
struct TileOptimizations
{
int num_groups_to_merge = 1;
bool split_image = false;
bool explicit_gemm = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipelineScheduler SetTileScheduler()
{
constexpr auto scheduler = ALGORITHM.block_gemm.scheduler;
using ck_tile_sched = ck_tile::GemmPipelineScheduler;
switch(scheduler)
{
case PipelineScheduler::DEFAULT: return ck_tile_sched::Default;
case PipelineScheduler::INTERWAVE: return ck_tile_sched::Interwave;
case PipelineScheduler::INTRAWAVE: return ck_tile_sched::Intrawave;
default: throw "Unknown PipelineScheduler";
}
}
template <ck_tile::GemmPipeline PipelineId>
struct TilePipelineType
{
static_assert(false, "Unknown PipelineScheduler");
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_V1>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.block_gemm.pipeline_version;
using ck_tile_pipeline = ck_tile::GemmPipeline;
switch(version)
{
case PipelineVersion::V1: return ck_tile_pipeline::BASIC_V1;
case PipelineVersion::V2: return ck_tile_pipeline::MEMORY;
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization()
{
constexpr auto specialization = ALGORITHM.specialization;
using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
switch(specialization)
{
case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0:
return ck_tile_conv_spec::Filter1x1Stride1Pad0;
case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
default: throw "Unknown ConvFwdSpecialization";
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileBlockGemmSpec SetTileBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm;
constexpr bool double_smem_buffer = BG.double_smem_buffer;
constexpr int num_wave_groups = BG.num_wave_groups;
constexpr ck_tile::GemmPipeline pipeline_version = SetTileBlockGemmPipelineVersion<ALGORITHM>();
constexpr ck_tile::GemmPipelineScheduler scheduler = SetTileScheduler<ALGORITHM>();
return TileBlockGemmSpec{
.warps = {.m = BG.warps.m, .n = BG.warps.n, .k = BG.warps.k},
.warp_tile = {.m = BG.warp_tile.m, .n = BG.warp_tile.n, .k = BG.warp_tile.k},
.double_smem_buffer = double_smem_buffer,
.num_wave_groups = num_wave_groups,
.pipeline_version = pipeline_version,
.scheduler = scheduler};
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileOptimizations SetTileOptimizations()
{
constexpr auto& OPT = ALGORITHM.optimizations;
return TileOptimizations{.num_groups_to_merge = OPT.num_groups_to_merge,
.split_image = OPT.split_image,
.explicit_gemm = OPT.explicit_gemm};
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -251,14 +251,10 @@ class ConvDescription : public Description
};
} // namespace conv
/// @brief Helper concept to detect if a type has ConvTraits specialization
template <typename T>
concept HasConvTraits = requires { typename conv::ConvTraits<T>; };
/// @brief Factory function to create ConvDescription from a convolution instance type
/// @tparam Instance The convolution instance type (must have InstanceTraits specialization)
/// @tparam Instance The convolution instance type (must have ConvTraits specialization)
/// @return A ConvDescription object populated with the instance's configuration details
template <HasConvTraits Instance>
template <conv::HasConvTraits Instance>
conv::ConvDescription describe()
{
using Traits = conv::ConvTraits<Instance>;

View File

@@ -4,23 +4,74 @@
#pragma once
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include <ck_tile/ops/grouped_convolution.hpp>
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
namespace ck_tile::reflect::conv {
// Forward convolution layout concept - checks for A/B/E layout types
template <typename T>
concept HasFwdConvLayouts = requires {
typename T::ALayout;
typename T::BLayout;
typename T::ELayout;
};
// GEMM specialization concept - checks for kGemmSpecialization member
template <typename T>
concept HasGemmSpec = requires {
{
T::kGemmSpecialization
} -> std::convertible_to<ck::tensor_operation::device::GemmSpecialization>;
};
// Data types concept - checks for ADataType member
template <typename T>
concept HasDataTypes = requires { typename T::ADataType; };
// Elementwise operations concept - checks for A/B/CDE elementwise operation types
template <typename T>
concept HasElementwiseOps = requires {
typename T::AElementwiseOperation;
typename T::BElementwiseOperation;
typename T::CDEElementwiseOperation;
};
// Tile parameters concept - checks for tile dimension and transfer members
template <typename T>
concept HasTileParams = requires {
{ T::kKPerBlock } -> std::convertible_to<int>;
{ T::kMPerBlock } -> std::convertible_to<int>;
{ T::kNPerBlock } -> std::convertible_to<int>;
{ T::kAK1 } -> std::convertible_to<int>;
{ T::kBK1 } -> std::convertible_to<int>;
T::kCThreadClusterLengths;
};
// Comprehensive concept that checks if an instance has all XDL forward convolution traits
// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions
template <typename T>
concept IsXdlFwdConv = HasFwdConvLayouts<T> && HasGemmSpec<T> && HasDataTypes<T> &&
HasElementwiseOps<T> && HasTileParams<T>;
// Primary concept for checking if a type can be described
// Currently only forward convolutions are supported, but this can be extended
// in the future to include backward data and backward weight convolutions
template <typename T>
concept HasConvTraits = IsXdlFwdConv<InstanceTraits<T>>;
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
@@ -35,16 +86,15 @@ constexpr auto convert_pipeline_version()
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v3)
return V3;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == v5)
return V5;
switch(ck_ver)
{
case v1: return V1;
case v2: return V2;
case v3: return V3;
case v4: return V4;
case v5: return V5;
}
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
@@ -59,14 +109,14 @@ constexpr auto convert_pipeline_version()
{
using enum ck::PipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == weight_only)
return WEIGHT_ONLY;
switch(ck_ver)
{
case v1: return V1;
case v2: return V2;
case v4: return V4;
case weight_only: return WEIGHT_ONLY;
}
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
@@ -82,10 +132,12 @@ constexpr auto convert_pipeline_scheduler()
{
using enum ck::BlockGemmPipelineScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Intrawave)
return INTRAWAVE;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
switch(ck_sched)
{
case Intrawave: return INTRAWAVE;
case Interwave: return INTERWAVE;
}
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
@@ -101,10 +153,12 @@ constexpr auto convert_pipeline_scheduler()
{
using enum ck::LoopScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Default)
return DEFAULT;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
switch(ck_sched)
{
case Default: return DEFAULT;
case Interwave: return INTERWAVE;
}
}
/// @brief Helper structures for organizing trait data with domain-specific naming
@@ -213,21 +267,13 @@ constexpr builder::ConvDirection conv_direction()
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
{
return builder::ConvDirection::FORWARD;
}
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
{
return builder::ConvDirection::BACKWARD_DATA;
}
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
{
return builder::ConvDirection::BACKWARD_WEIGHT;
}
else
{
return builder::ConvDirection::FORWARD; // Default fallback
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
@@ -242,60 +288,52 @@ constexpr auto conv_spec()
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
using enum builder::ConvFwdSpecialization;
if constexpr(InstTraits::kConvForwardSpecialization == Default)
switch(InstTraits::kConvForwardSpecialization)
{
return builder::ConvFwdSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
{
return builder::ConvFwdSpecialization::FILTER_3x3;
case Default: return DEFAULT;
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter3x3: return FILTER_3x3;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
using enum builder::ConvBwdDataSpecialization;
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
switch(InstTraits::kConvBwdDataSpecialization)
{
return builder::ConvBwdDataSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
case Default: return DEFAULT;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
}
}
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
using enum builder::ConvBwdWeightSpecialization;
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
switch(InstTraits::kConvBwdWeightSpecialization)
{
return builder::ConvBwdWeightSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
{
return builder::ConvBwdWeightSpecialization::ODD_C;
case Default: return DEFAULT;
case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0;
case Filter1x1Pad0: return FILTER_1X1_PAD0;
case OddC: return ODD_C;
}
}
}
// Helper variable template to check if CK layout enums match
template <typename A,
typename B,
typename E,
typename ExpectedA,
typename ExpectedB,
typename ExpectedE>
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
@@ -304,112 +342,49 @@ constexpr auto conv_spec()
/// index 2 -> Output layout
template <typename Instance>
constexpr auto conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using ALayout = typename InstTraits::ALayout;
using BLayout = typename InstTraits::BLayout;
using ELayout = typename InstTraits::ELayout;
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
namespace ctc = ck::tensor_layout::convolution;
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
namespace ctl = ck::tensor_layout::convolution;
using enum builder::TensorLayout;
if constexpr(InstTraits::kSpatialDim == 1)
switch(InstanceTraits<Instance>::kSpatialDim)
{
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
std::is_same_v<ELayout, ctc::GNWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNWC,
builder::TensorLayout::GKXC,
builder::TensorLayout::GNWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NWGC,
builder::TensorLayout::GKXC,
builder::TensorLayout::NWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
builder::TensorLayout::GKXC,
builder::TensorLayout::NGKW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCW,
builder::TensorLayout::GKCX,
builder::TensorLayout::NGKW};
}
}
else if constexpr(InstTraits::kSpatialDim == 2)
{
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::GNHWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNHWC,
builder::TensorLayout::GKYXC,
builder::TensorLayout::GNHWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NHWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NHWGC,
builder::TensorLayout::GKYXC,
builder::TensorLayout::NHWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
builder::TensorLayout::GKYXC,
builder::TensorLayout::NGKHW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKCYX> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCHW,
builder::TensorLayout::GKCYX,
builder::TensorLayout::NGKHW};
}
}
else if constexpr(InstTraits::kSpatialDim == 3)
{
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::GNDHWK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::GNDHWC,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::GNDHWK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NDHWGK>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NDHWGC,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::NDHWGK};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
builder::TensorLayout::GKZYXC,
builder::TensorLayout::NGKDHW};
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKCZYX> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return std::array<builder::TensorLayout, 3>{builder::TensorLayout::NGCDHW,
builder::TensorLayout::GKCZYX,
builder::TensorLayout::NGKDHW};
}
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
return layouts(NGCW, GKXC, NGKW);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
return layouts(NGCW, GKCX, NGKW);
break;
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
return layouts(NGCHW, GKCYX, NGKHW);
break;
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
return layouts(NGCDHW, GKZYXC, NGKDHW);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
}
@@ -418,39 +393,26 @@ constexpr auto conv_layout()
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
template <typename Instance>
constexpr builder::DataType conv_data_type()
requires HasDataTypes<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
using enum builder::DataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
{
return builder::DataType::FP16;
}
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
{
return builder::DataType::BF16;
}
return BF16;
else if constexpr(std::is_same_v<ADataType, float>)
{
return builder::DataType::FP32;
}
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
{
return builder::DataType::FP8;
}
return FP8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
{
return builder::DataType::I8;
}
return I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
{
return builder::DataType::U8;
}
return U8;
else
{
// Default fallback
return builder::DataType::FP32;
}
return FP32; // Default fallback
}
/// @brief Derives the elementwise operation from op type.
@@ -459,27 +421,19 @@ constexpr builder::DataType conv_data_type()
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
using enum builder::ElementwiseOperation;
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
{
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
{
return builder::ElementwiseOperation::CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
{
return builder::ElementwiseOperation::SCALE;
}
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
{
return builder::ElementwiseOperation::PASS_THROUGH;
}
else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
{
return builder::ElementwiseOperation::SCALEADD_SCALEADD_RELU;
}
return BIAS_BNORM_CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Clamp"))
return CLAMP;
if constexpr(detail::case_insensitive_equal(name, "Scale"))
return SCALE;
if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
return PASS_THROUGH;
if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
return SCALEADD_SCALEADD_RELU;
}
/// @brief Derives a gemm padding from a kernel instance type.
@@ -487,6 +441,7 @@ constexpr builder::ElementwiseOperation elementwise_op()
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
requires HasGemmSpec<InstanceTraits<Instance>>
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::GemmPadding;
@@ -494,69 +449,24 @@ constexpr builder::GemmPadding gemm_spec()
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
if constexpr(gemm_spec == Default)
switch(gemm_spec)
{
return DEFAULT;
}
else if constexpr(gemm_spec == MPadding)
{
return M_PADDING;
}
else if constexpr(gemm_spec == NPadding)
{
return N_PADDING;
}
else if constexpr(gemm_spec == KPadding)
{
return K_PADDING;
}
else if constexpr(gemm_spec == MNPadding)
{
return MN_PADDING;
}
else if constexpr(gemm_spec == MKPadding)
{
return MK_PADDING;
}
else if constexpr(gemm_spec == NKPadding)
{
return NK_PADDING;
}
else if constexpr(gemm_spec == MNKPadding)
{
return MNK_PADDING;
}
else if constexpr(gemm_spec == OPadding)
{
return O_PADDING;
}
else if constexpr(gemm_spec == MOPadding)
{
return MO_PADDING;
}
else if constexpr(gemm_spec == NOPadding)
{
return NO_PADDING;
}
else if constexpr(gemm_spec == KOPadding)
{
return KO_PADDING;
}
else if constexpr(gemm_spec == MNOPadding)
{
return MNO_PADDING;
}
else if constexpr(gemm_spec == MKOPadding)
{
return MKO_PADDING;
}
else if constexpr(gemm_spec == NKOPadding)
{
return NKO_PADDING;
}
else if constexpr(gemm_spec == MNKOPadding)
{
return MNKO_PADDING;
case Default: return DEFAULT;
case MPadding: return M_PADDING;
case NPadding: return N_PADDING;
case KPadding: return K_PADDING;
case MNPadding: return MN_PADDING;
case MKPadding: return MK_PADDING;
case NKPadding: return NK_PADDING;
case MNKPadding: return MNK_PADDING;
case OPadding: return O_PADDING;
case MOPadding: return MO_PADDING;
case NOPadding: return NO_PADDING;
case KOPadding: return KO_PADDING;
case MNOPadding: return MNO_PADDING;
case MKOPadding: return MKO_PADDING;
case NKOPadding: return NKO_PADDING;
case MNKOPadding: return MNKO_PADDING;
}
}
@@ -571,6 +481,7 @@ struct ConvTraits;
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <HasInstanceTraits Instance>
requires IsXdlFwdConv<InstanceTraits<Instance>>
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;

View File

@@ -8,28 +8,30 @@
#pragma once
#include <array>
#include <string>
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <limits.h>
#include <cmath>
#include <ostream>
#include <concepts>
#include <iostream>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck_tile/ops/common/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include <limits.h>
#include <ostream>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"

View File

@@ -145,6 +145,15 @@ enum struct GemmSpecialization
MNKOPadding
};
// Enums for the CK Tile convolution specialization.
enum class TileConvSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
{

View File

@@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/test_conv_traits.cpp)
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
@@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp
conv/test_ckb_conv_fwd_2d_fp8.cpp
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
conv/ck/test_ckb_conv_fwd_1d_bf16.cpp
conv/ck/test_ckb_conv_fwd_1d_i8.cpp
conv/ck/test_ckb_conv_fwd_2d_fp8.cpp
conv/ck/test_ckb_conv_fwd_2d_bf16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp32.cpp
conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_bf16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
)

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_DATA,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_data",
"fp16",
"NHWGC_GKYXC_NHWGK",
"64x64x64",
"2x2",
"16x16x16",
// "4x4x4", // TODO: Enable this check
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",
"MergedGroups_1",
"SplitImage_0",
"ExplicitGemm_0",
});
}
} // namespace

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_weight",
"fp16",
"NHWGC_GKYXC_NHWGK",
"64x64x64",
"2x2",
"16x16x16",
// "4x4x4", // TODO: Enable this check
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",
"MergedGroups_1",
"SplitImage_0",
"ExplicitGemm_0",
});
}
} // namespace

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_forward",
"fp16",
"NHWGC_GKYXC_NHWGK",
"64x64x64",
"2x2",
"16x16x16",
// "4x4x4", // TODO: Enable this check
"Default",
"Intrawave",
"CShuffleEpilogue",
"set",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",
"MergedGroups_1",
"SplitImage_0",
"ExplicitGemm_0",
});
}
} // namespace

View File

@@ -243,6 +243,73 @@ struct LargeTensorWrapper
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Specify thread block dimensions for a GEMM (CK Tile).
struct TileThreadBlock
{
// Size of the submatrix problem in a thread block.
MNK<size_t> tile_size;
};
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
struct TileTransfer
{
size_t a_scalar_per_vector;
size_t b_scalar_per_vector;
size_t c_scalar_per_vector;
};
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
struct TileBlockGemm
{
// Number of warps per each dimension.
MNK<int> warps;
// Number of data processed per each dimension for each XDL/WMMA instruction.
MNK<int> warp_tile;
// Double LDS buffer.
bool double_smem_buffer;
// Waves grouping (Ping-Pong scheduler).
int num_wave_groups;
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
struct TileOptimizations
{
// Number of convolution groups processed per one workgroup
int num_groups_to_merge;
// Split image for large tensors
bool split_image;
// Explicit gemm for 1x1, stride=0, pad=0 cases
bool explicit_gemm;
};
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
struct TileConvSpecialization_
{
TileConvSpecialization specialization;
};
struct TileThreadBlock_
{
TileThreadBlock thread_block;
};
struct TileTransfer_
{
TileTransfer transfer;
};
struct TileBlockGemm_
{
TileBlockGemm block_gemm;
};
struct TileOptimizations_
{
TileOptimizations optimizations;
};
// Factory
template <typename... Components>
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
result.transfer = t;
return result;
}
template <typename S>
constexpr auto with_tile_specializations(const S& s) const
{
static_assert(std::is_base_of_v<TileConvSpecialization_, ConvAlgorithmTemplate>);
auto result = *this;
result.specialization = s;
return result;
}
template <typename TB>
constexpr auto with_tile_thread_block(const TB& tb) const
{
static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
auto result = *this;
result.thread_block = tb;
return result;
}
template <typename BG>
constexpr auto with_tile_block_gemm(const BG& bg) const
{
static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_gemm = bg;
return result;
}
template <typename T>
constexpr auto with_tile_transfer(const T& t) const
{
static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
}
template <typename O>
constexpr auto with_tile_optimizations(const O& o) const
{
static_assert(std::is_base_of_v<TileOptimizations_, ConvAlgorithmTemplate>);
auto result = *this;
result.optimizations = o;
return result;
}
};
// Algorithm types
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
TileConvSpecialization_,
TileOptimizations_>;
} // namespace ck_tile::builder::test

View File

@@ -2,9 +2,10 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck/ck.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp>
#include "ck/ck.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
namespace {

View File

@@ -2,10 +2,12 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck/ck.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
namespace {

View File

@@ -1,17 +1,19 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck/ck.hpp>
#include <ck/utility/reduction_operator.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
namespace {

View File

@@ -1,16 +1,16 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <gtest/gtest.h>
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/sequence.hpp"
#include "ck_tile/builder/reflect/instance_traits_util.hpp"
namespace ck_tile::reflect::detail {
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "impl/conv_signature_types.hpp"
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
namespace {

View File

@@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace {

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
namespace {

View File

@@ -28,4 +28,20 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
}
}
// Common CK Tile test implementation
template <typename Builder>
constexpr void run_ck_tile_test(const std::vector<std::string>& kernel_instance_components)
{
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
std::cout << kernel_string << std::endl;
for(const auto& component : kernel_instance_components)
{
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
}
}
} // namespace ck_tile::builder::test_utils

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
constexpr TileTransfer FwdTileTransfer_1x1x1{
.a_scalar_per_vector = 1,
.b_scalar_per_vector = 1,
.c_scalar_per_vector = 1,
};
constexpr TileTransfer FwdTileTransfer_4x4x4{
.a_scalar_per_vector = 4,
.b_scalar_per_vector = 4,
.c_scalar_per_vector = 4,
};
constexpr TileTransfer FwdTileTransfer_8x8x8{
.a_scalar_per_vector = 8,
.b_scalar_per_vector = 8,
.c_scalar_per_vector = 8,
};
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},
.warp_tile = {.m = 16, .n = 16, .k = 16},
.double_smem_buffer = false,
.num_wave_groups = 1,
.pipeline_version = PipelineVersion::V1,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},
.warp_tile = {.m = 16, .n = 16, .k = 16},
.double_smem_buffer = false,
.num_wave_groups = 1,
.pipeline_version = PipelineVersion::V2,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},
.warp_tile = {.m = 16, .n = 16, .k = 16},
.double_smem_buffer = false,
.num_wave_groups = 1,
.pipeline_version = PipelineVersion::V3,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},
.warp_tile = {.m = 16, .n = 16, .k = 16},
.double_smem_buffer = false,
.num_wave_groups = 1,
.pipeline_version = PipelineVersion::V4,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},
.warp_tile = {.m = 16, .n = 16, .k = 16},
.double_smem_buffer = false,
.num_wave_groups = 1,
.pipeline_version = PipelineVersion::V5,
.scheduler = PipelineScheduler::INTRAWAVE};
} // namespace ck_tile::builder::test_utils

View File

@@ -55,6 +55,11 @@
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#define CK_ENABLE_TF32 "ON"
#endif
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
@@ -85,6 +90,12 @@
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
#endif
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif

View File

@@ -5,24 +5,16 @@
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <iostream>
#include <ostream>
#endif
#include "ck/utility/pipeline_enum.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp"
namespace ck {
enum struct PipelineVersion
{
v1,
v2,
// v3 is only used in the Stream-K implementation.
v4,
weight_only,
};
template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default,
@@ -62,18 +54,3 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
switch(p)
{
case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
default: os << "";
}
return os;
}
#endif

View File

@@ -3,52 +3,12 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
enum struct BlockGemmPipelineVersion
{
// For GEMM
v1, // Naive
v2, // Mem
v3, // Comp
v4, // Comp, double lds buffer
v5, // Comp, double global prefetch register buffer
// For GEMM with preshuffled weight
// v1, single lds buffer
// v2, double lds buffer
};
enum struct BlockGemmPipelineScheduler
{
Intrawave,
Interwave,
};
enum struct TailNumber
{
// Single / Double buffer pipeline
Odd,
Even,
// Long prefetch pipeline, up to 8
One,
Two,
Three,
Four,
Five,
Six,
Seven,
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
Empty,
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
// prefetchstages
Full,
};
enum SchedulerGroup : uint32_t
{
SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions

View File

@@ -3,40 +3,20 @@
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
#include "ck/utility/common_header.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
Interwave,
};
/// @brief Helper function to get default loop scheduler
/// @details Returns the default loop scheduler based on compile-time configuration.
constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return LoopScheduler::Interwave;
#else
return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
#endif
}
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
{
case ck::LoopScheduler::Default: os << "Default"; break;
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
default: os << "";
}
return os;
}
#endif

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
namespace ck {
/// @brief Pipeline version enumeration for GEMM kernels
/// @details Defines different pipeline strategies for data movement and computation overlap
/// in GEMM kernels. This is a lightweight header containing only the enum definition,
/// extracted from gridwise_gemm_pipeline_selector.hpp to minimize dependencies.
enum struct PipelineVersion
{
v1, ///< Version 1 pipeline
v2, ///< Version 2 pipeline
// v3 is only used in the Stream-K implementation.
v4, ///< Version 4 pipeline
weight_only, ///< Weight-only specialized pipeline
};
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
switch(p)
{
case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
default: os << "";
}
return os;
}
#endif

View File

@@ -0,0 +1,83 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
namespace ck {
/// @brief Block GEMM pipeline version enumeration
/// @details Defines different block GEMM pipeline strategies.
/// This is a lightweight header containing only enum definitions,
/// extracted from blkgemmpipe_scheduler.hpp to minimize dependencies.
enum struct BlockGemmPipelineVersion
{
// For GEMM
v1, ///< Naive pipeline
v2, ///< Memory-optimized pipeline
v3, ///< Compute-optimized pipeline
v4, ///< Compute-optimized with double LDS buffer
v5, ///< Compute-optimized with double global prefetch register buffer
// For GEMM with preshuffled weight
// v1, single lds buffer
// v2, double lds buffer
};
/// @brief Block GEMM pipeline scheduler enumeration
/// @details Defines scheduling strategies for block GEMM pipelines.
enum struct BlockGemmPipelineScheduler
{
Intrawave, ///< Schedule within a single wavefront
Interwave, ///< Schedule across multiple wavefronts
};
/// @brief Loop scheduler enumeration
/// @details Defines scheduling strategies for computational loops.
enum struct LoopScheduler
{
Default, ///< Default scheduling strategy
Interwave, ///< Cross-wavefront scheduling
};
/// @brief Tail number enumeration for pipeline buffering
/// @details Defines the number of tail iterations in pipelined loops.
enum struct TailNumber
{
// Single / Double buffer pipeline
Odd, ///< Odd number of iterations
Even, ///< Even number of iterations
// Long prefetch pipeline, up to 8
One, ///< One tail iteration
Two, ///< Two tail iterations
Three, ///< Three tail iterations
Four, ///< Four tail iterations
Five, ///< Five tail iterations
Six, ///< Six tail iterations
Seven, ///< Seven tail iterations
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
Empty, ///< No tail iterations
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
// prefetchstages
Full, ///< Full tail iterations
};
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
{
case ck::LoopScheduler::Default: os << "Default"; break;
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
default: os << "";
}
return os;
}
#endif

View File

@@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>&
printf("}");
}
template <typename Functor, typename LowLength>
struct functor_transform : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
Functor functor_;
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr functor_transform() = default;
CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor,
const LowLength& low_length)
: functor_{functor}, up_lengths_{make_tuple(low_length)}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = functor_(idx_up[number<0>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& up_idx) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
const auto idx_low_old = idx_low;
calculate_lower_index(idx_low, up_idx);
idx_diff_low = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// Note: When using functor_transform, ensure that the transformed coordinates
// are always valid for vectorized load/store operations.
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
};
//*******************************************************************************************************
template <typename LowLength>
@@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
return offset<LowLength, OffsetLength>{low_length, offset_length};
}
template <typename Functor, typename LowLength>
CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor,
const LowLength& low_length)
{
return functor_transform<Functor, LowLength>{functor, low_length};
}
} // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"

View File

@@ -357,6 +357,12 @@ struct amdgcn_compiler_target_state
#endif // __gfx950__
// GFX10
#if defined(__gfx1010__)
static constexpr bool CK_TILE_ARCH_GFX1010 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
#else
@@ -493,6 +499,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX90A, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -533,7 +533,8 @@ struct tile_scatter_gather
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
using Traits = load_store_traits;

View File

@@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths
}
};
template <typename TensorView_, typename WindowLengths_>
template <typename TensorView_,
typename WindowLengths_,
typename = std::enable_if_t<is_tensor_view_v<TensorView_>>>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
@@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,
typename = std::enable_if_t<is_tile_distribution_v<StaticTileDistribution>>>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution,

View File

@@ -517,7 +517,8 @@ struct tile_window_linear
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(m0_init_value); // This should be wave independent
m0_set_with_memory(
amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
using vector_t = typename Base::Traits::vector_t;

View File

@@ -33,59 +33,73 @@ namespace ck_tile {
* @example
*
* // Direct usage without creating a separate variable:
* ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host_tensor);
* ck_tile::FillUniformDistribution<>{-1.f, 1.f}(a_host_tensor);
*/
template <typename T>
template <typename T = void>
struct FillUniformDistribution
{
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed
// across threads).
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
if(first == last)
return;
using T_iter = std::decay_t<decltype(*first)>;
static_assert(std::is_same_v<T, T_iter> || std::is_void_v<T>,
"Iterator value type must match template type T");
constexpr auto PackedSize = numeric_traits<T_iter>::PackedSize;
const auto total = static_cast<size_t>(std::distance(first, last));
const auto total_bytes = total * sizeof(T_iter);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
// max 80 threads; at least 2MB per thread
const size_t available_cpu_cores = get_available_cpu_cores();
const size_t num_thread =
min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
constexpr size_t BLOCK_BYTES = 64;
constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter);
const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES);
const size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread);
// use minstd_rand for better performance on discard()
std::minstd_rand gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::vector<joinable_thread> threads;
threads.reserve(num_thread - 1); // last job run in the main thread
for(int it = num_thread - 1; it >= 0; --it)
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() {
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
const size_t ib_begin = it * blocks_per_thread;
const size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks);
auto job = [=]() {
auto g_ = gen; // copy
auto d_ = dis; // copy
g_.discard(ib_begin * BLOCK_SIZE * PackedSize);
auto t_fn = [&]() {
if constexpr(PackedSize == 2)
return type_convert<T_iter>(fp32x2_t{d_(g_), d_(g_)});
else
return type_convert<T_iter>(d_(g_));
};
size_t ib = ib_begin;
for(; ib < ib_end - 1; ++ib) // full blocks
static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) {
constexpr size_t iw = iw_.value;
*(first + ib * BLOCK_SIZE + iw) = t_fn();
});
for(size_t iw = 0; iw < BLOCK_SIZE; ++iw) // last block
if(ib * BLOCK_SIZE + iw < total)
*(first + ib * BLOCK_SIZE + iw) = t_fn();
};
if(it > 0)
threads.emplace_back(std::move(job));
else
job(); // last job run in the main thread
}
}

View File

@@ -3,6 +3,9 @@
#pragma once
#ifdef __linux__
#include <sched.h>
#endif
#include <thread>
#include <utility>
@@ -24,4 +27,50 @@ struct joinable_thread : std::thread
this->join();
}
};
inline unsigned int get_available_cpu_cores()
{
#if defined(__linux__)
cpu_set_t cpu_set;
if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0)
{
unsigned int cpu_count = CPU_COUNT(&cpu_set);
if(cpu_count > 0)
return cpu_count;
}
#endif
// Fallback if sched_getaffinity unavailable or fails
return std::thread::hardware_concurrency();
}
class cpu_core_guard
{
#if defined(__linux__)
cpu_set_t original_cpu_set_;
public:
cpu_core_guard(unsigned int num_cores) : original_cpu_set_()
{
// save original cpu set
sched_getaffinity(0, sizeof(cpu_set_t), &original_cpu_set_);
// set new cpu set
cpu_set_t new_cpu_set;
CPU_ZERO(&new_cpu_set);
for(unsigned int i = 0; i < num_cores; ++i)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
CPU_SET(i, &new_cpu_set); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set);
}
~cpu_core_guard()
{
// restore original cpu set
sched_setaffinity(0, sizeof(cpu_set_t), &original_cpu_set_);
}
#endif
};
} // namespace ck_tile

View File

@@ -1259,12 +1259,12 @@ struct MoeFlatmmKernel
auto fused_token =
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
index_t scatter_token_id = fused_token & token_id_mask;
index_t scatter_token_id = fused_token & token_id_mask;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
if constexpr(IsInputGemm)
scatter_token_id =
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
});
});

View File

@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using MXFlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using ALayout = remove_cvref_t<typename MXFlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename MXFlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename MXFlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
static constexpr index_t KernelBlockSize = MXFlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
using ADataType = remove_cvref_t<typename MXFlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename MXFlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, MXFlatmmPipeline::GetName());
// clang-format on
}
@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset& splitk_batch_offset)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<MXFlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}();
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, MXFlatmmPipeline::kPadK>{});
}();
const auto& b_flat_tensor_view = views.at(I1);
@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
sequence<false, MXFlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
sequence<false, MXFlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
sequence<MXFlatmmPipeline::kPadM, false>{});
}
}();
@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>,
"A tensor for mx must be RowMajor");
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
make_tuple(number<MXFlatmmPipeline::flatNPerWarp>{},
number<MXFlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
MXFlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(DoEpiScale)
@@ -487,10 +461,10 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
const auto a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
splitk_batch_offset.a_k_split_offset / APackedSize;
const auto b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / BPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,

Some files were not shown because too many files have changed in this diff Show More