Merge remote-tracking branch 'origin/develop' into tianyuwu/ck_tile/WMMA_GEMM_F16

This commit is contained in:
Tianyuan Wu
2025-08-12 07:10:01 +00:00
30 changed files with 3101 additions and 817 deletions

View File

@@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).

View File

@@ -1,16 +1,20 @@
ARG BASE_DOCKER="rocm/pytorch:latest"
FROM $BASE_DOCKER
RUN groupadd -f render && \
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN groupadd -g 109 render && \
usermod -u 1001 jenkins && \
groupmod -g 1001 jenkins && \
pip install pandas zmq einops && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \
sudo mkdir /home/jenkins/workspace && \
cd /home/jenkins/workspace && \
rm -rf aiter && \
git clone --recursive https://github.com/ROCm/aiter.git && \
git clone -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
cd aiter && \
rm -rf 3rdparty/composable_kernel/ && \
git clone https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
python3 setup.py develop && \
chown -R jenkins:jenkins /home/jenkins/workspace && \
chmod -R a+rwx /home/jenkins/workspace && \

165
Jenkinsfile vendored
View File

@@ -190,7 +190,7 @@ def buildDocker(install_prefix){
}
else if(params.RUN_AITER_TESTS){
image_name = "rocm/composable_kernel:ck_aiter"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter . "
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
}
else{
dockerArgs = dockerArgs + " -f Dockerfile . "
@@ -438,34 +438,6 @@ def cmake_build(Map conf=[:]){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
try{
archiveArtifacts "perf_transpose_*.log"
if (arch_type == 1){
stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a"
}
else if (arch_type == 2){
stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942"
}
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
if (params.RUN_CK_TILE_GEMM_TESTS){
try{
archiveArtifacts "perf_tile_gemm_**.log"
if (arch == 1){
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
}
else if (arch == 2){
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
}
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
}
}
}
def buildHipClangJob(Map conf=[:]){
@@ -762,24 +734,6 @@ def process_results(Map conf=[:]){
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
}
}
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
try{
unstash "perf_transpose_log_gfx942"
unstash "perf_transpose_log_gfx90a"
}
catch(Exception err){
echo "could not locate the Transpose performance logs: ${err.getMessage()}."
}
}
if (params.RUN_CK_TILE_GEMM_TESTS){
try{
unstash "perf_tile_gemm_log_gfx942"
unstash "perf_tile_gemm_log_gfx90a"
}
catch(Exception err){
echo "could not locate the GEMM performance logs: ${err.getMessage()}."
}
}
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
// unstash deb packages
unstash "packages"
@@ -843,10 +797,11 @@ def run_aiter_tests(Map conf=[:]){
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 45, unit: 'MINUTES'){
try{
sh "python3 --version"
sh "rocminfo"
sh "python3 ../aiter/op_tests/test_gemm_a8w8_blockscale.py"
//sh "python3 ../aiter/op_tests/test_mha.py"
sh "python3 --version"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
}
catch(e){
echo "Throwing error exception while running AITER tests"
@@ -861,7 +816,7 @@ def run_aiter_tests(Map conf=[:]){
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
@@ -941,14 +896,6 @@ pipeline {
name: "RUN_CK_TILE_FMHA_TESTS",
defaultValue: false,
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_TRANSPOSE_TESTS",
defaultValue: false,
description: "Run the ck_tile Transpose tests (default: OFF)")
booleanParam(
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: false,
description: "Run the ck_tile GEMM tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
@@ -1009,6 +956,14 @@ pipeline {
name: "RUN_AITER_TESTS",
defaultValue: false,
description: "Run AITER tests with latest CK develop branch (default: OFF)")
string(
name: 'aiter_branch',
defaultValue: 'main',
description: 'Specify which branch of AITER to use (default: main)')
string(
name: 'ck_aiter_branch',
defaultValue: 'develop',
description: 'Specify which branch of CK to test with AITER (default: develop)')
}
environment{
dbuser = "${dbuser}"
@@ -1093,13 +1048,13 @@ pipeline {
{
parallel
{
stage("Run AITER Tests on gfx90a")
stage("Run AITER Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_AITER_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a")}
agent{ label rocmnode("gfx942")}
steps{
run_aiter_tests()
cleanWs()
@@ -1198,94 +1153,6 @@ pipeline {
}
}
}
stage("Run CK_TILE_TRANSPOSE Tests")
{
parallel
{
stage("Run CK_TILE_TRANSPOSE Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 tile_example_batched_transpose && \
cd ../ &&
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run CK_TILE_TRANSPOSE Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_batched_transpose && \
cd ../ &&
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Run CK_TILE_GEMM Tests")
{
parallel
{
stage("Run CK_TILE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 tile_example_gemm_universal && \
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run CK_TILE_GEMM Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_gemm_universal && \
cd ../ &&
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Run TILE_ENGINE_GEMM Tests")
{
parallel

View File

@@ -96,7 +96,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
4. Build the entire CK library:
```bash
make -j
make -j"$(nproc)"
```
5. Install CK:
@@ -213,4 +213,4 @@ script/uninstall_precommit.sh
```
If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the
`git commit` command.
`git commit` command.

View File

@@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
{F_deterministic},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_trload},
fmha_bwd_trait_{F_idx}>;
using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_{F_idx}>;
@@ -113,7 +114,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dbias},
{F_dpad},
{F_dvpad},
{F_deterministic}>;
{F_deterministic},
{F_trload}>;
#include <iostream>
@@ -168,29 +170,35 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
const bool has_load_tr = ck_tile::is_load_tr_supported();
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
{F_body}
}}
"""
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_body}
}}
"""
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_body}
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
"""
# M0 size for 1d kernels (dot/convert)
@@ -250,6 +258,7 @@ class FmhaBwdDQDKDVKernel:
F_mode : str # value from MODE_MAP
F_deterministic : str #
mask_impl : str #
F_trload : str #
@property
def template(self) -> str:
@@ -291,6 +300,7 @@ class FmhaBwdDQDKDVKernel:
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_trload = BOOL_MAP[self.F_trload],
)
@property
@@ -324,6 +334,9 @@ class FmhaBwdDQDKDVKernel:
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
if self.F_trload == 't' : n += '_trload'
else: n += '_ntrload'
return n
@property
@@ -332,8 +345,8 @@ class FmhaBwdDQDKDVKernel:
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return {
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -341,6 +354,10 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
}
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return {
'128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
}
else:
return None
@@ -573,6 +590,7 @@ class FmhaBwdApiTrait:
dvpad : str
deterministic : str
mask_impl : str
tr_load : bool
@property
def bm0(self) -> int:
@@ -620,7 +638,7 @@ class FmhaBwdApiTrait:
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl)
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
@property
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
@@ -636,12 +654,13 @@ class FmhaBwdApiTrait:
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list))
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
@staticmethod
def if_(i: int) -> str:
@@ -656,24 +675,31 @@ class FmhaBwdApiPool:
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
i += 1
return inners
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.dq_dk_dv_pool):
per_hdim_case=str()
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]):
traits=self.dq_dk_dv_pool[dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners)
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
tr_load_cond_map = {
"t": "has_load_tr",
"f": "true"
}
per_tr_load = ''
for tr_load in ["t", "f"]:
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
per_tr_load += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
if filter_list == '':
@@ -690,8 +716,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
api_pool = FmhaBwdApiPool(mask_impl)
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
if d is None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
@@ -703,7 +729,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if ("wg32" in dropout):
continue
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
continue # tr_load cannot work with dpad or dvpad
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
continue

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
@@ -363,7 +364,8 @@ template <ck_tile::index_t HDim_,
bool kHasBiasGrad_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_>
bool kIsDeterministic_,
bool kUseTrLoad_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -376,6 +378,7 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
};
template <typename Traits_>

View File

@@ -526,9 +526,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof(VDataType) * hdim_v * real_seqlen_k +
sizeof(ODataType) * real_seqlen_q * hdim_v);
num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof(VDataType) * hdim_v * real_seqlen_k);
}
}

View File

@@ -0,0 +1,7 @@
add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp)
# Impact: This flag ensures that the compiler doesn't make
# assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns.
target_compile_options(tile_example_copy PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)

View File

@@ -0,0 +1,313 @@
# CK Tile Framework: Getting Started with Tile Copy Operations
## Overview
### Copy Kernel
A minimal CK_Tile memory copy implementation demonstrating the basic setup required to write a kernel in CK Tile.
This experimental kernel is intended for novice CK developers. It introduces the building blocks of CK Tile and provides a sandbox for experimenting with kernel parameters.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture
# (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the copy kernel executable
make tile_example_copy -j
```
This will result in an executable `build/bin/test_copy_basic`
## example
```
args:
-m input matrix rows. (default 64)
-n input matrix cols. (default 8)
-id wave to use for computation. (default 0)
-v validation flag to check device results. (default 1)
-prec datatype precision to use. (default fp16)
-warmup no. of warmup iterations. (default 50)
-repeat no. of iterations for kernel execution time. (default 100)
```
## CK Tile Architecture Components
The CK Tile framework is built around four key architectural components that work together to define and execute GPU kernels: shape, policy, problem, and pipeline.
### **1. Shape**
Defines the **hierarchical tile structure** and **memory layout** of the kernel:
```cpp
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
```
**Components:**
- **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N)
- **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`)
- **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`)
- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements)
**Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks.
### **2. Problem**
Defines the **data types** and **kernel configuration**:
```cpp
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
```
**Components:**
- **XDataType**: Input/output data type (e.g., `float`, `half`)
- **Shape**: The tile shape defined above
**Purpose**: Encapsulates **what** the kernel operates on and **how** it's configured.
### **3. Policy**
Defines the **memory access patterns** and **distribution strategies**:
```cpp
using Policy = ck_tile::TileCopyPolicy<Problem>;
```
**Key Functions:**
- **MakeDRAMDistribution()**: Defines how threads access DRAM memory.
**Purpose**: Defines **how** data is accessed and distributed across threads.
### **4. Pipeline**
Defines the **execution flow** and **memory movement patterns**:
```cpp
// Example pipeline stages:
// 1. DRAM → Registers (load_tile)
// 2. Registers → LDS (store_tile)
// 3. LDS → Registers (load_tile with distribution)
// 4. Registers → DRAM (store_tile)
```
**Purpose**: Defines the **sequence of operations** and **memory movement strategy**.
### **Component Interaction**
```cpp
// Complete kernel definition
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Policy = ck_tile::TileCopyPolicy<Problem>;
using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
```
**Flow:**
1. **Shape** defines the tile structure and work distribution
2. **Problem** combines data types with the shape
3. **Policy** defines memory access patterns for the problem
4. **Kernel** implements the actual computation using all components
### **Why This Architecture?**
#### **Separation of Concerns**
- **Shape**: Focuses on **work distribution** and **tile structure**
- **Problem**: Focuses on **data types** and **configuration**
- **Policy**: Focuses on **memory access** and **optimization**
- **Pipeline**: Focuses on **execution flow** and **synchronization**
#### **Reusability**
- Same **Shape** can be used with different **Problems**
- Same **Policy** can be applied to different **Shapes**
- **Pipelines** can be reused across different kernels
#### **Performance Optimization**
- **Shape** enables optimal work distribution
- **Policy** enables optimal memory access patterns
- **Pipeline** enables optimal execution flow
## Core Concepts
### Hierarchical Tile Structure
The CK Tile framework organizes work in a hierarchical manner:
1. **Vector**: Number of contiguous elements processed by a single thread
- Enables vectorized memory loads/stores.
- Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension
- A Vector can be imagined as a thread-level tile
2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD)
- Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize`
- This ensures the number of threads needed equals the wave size
- Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means:
- Each thread handles 4 elements (Vector_N = 4)
- Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements
- Total elements = 256, which requires WaveSize = 64 threads
3. **BlockTile**: Number of elements covered by one block (typically mapped to one CU)
- Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements
4. **BlockWaves**: Number of concurrent waves active in a block
- Usually 4 waves per block on modern AMD GPUs
- Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N
### Wave Repetition
In many scenarios, the total work (BlockTile) is larger than what the available waves can cover in a single iteration. This requires **wave repetition**:
```cpp
// Calculate how many times a wave needs to repeat to cover the entire block tile
static constexpr index_t WaveRepetitionPerBlock_M =
Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
static constexpr index_t WaveRepetitionPerBlock_N =
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
```
**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`.
## Tile Distribution Encoding
The tile distribution encoding specifies how work is distributed across threads:
```cpp
constexpr auto outer_encoding =
tile_distribution_encoding<sequence<1>, // replication
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>, // hierarchy
tuple<sequence<1>, sequence<1, 2>>, // parallelism
tuple<sequence<1>, sequence<2, 0>>, // paralleism
sequence<1, 2>, // yield
sequence<0, 1>>{}; // yield
```
### Encoding Parameters Explained
- **M0, M1, M2**: Hierarchical distribution along M dimension
- M0: Number of wave iterations along M
- M1: Number of waves along M
- M2: Number of threads per wave along M
- **N0, N1**: Distribution along N dimension
- N0: Number of threads along N
- N1: Vector size (elements per thread)
- **YIELD arguments**: Both `Repeat` and `Vector` because effective work per thread is `Vector * Repeat`
## Tensor Abstractions
### Tensor Descriptor
Defines the logical structure of a tensor:
```cpp
auto desc = make_naive_tensor_descriptor(
make_tuple(M, N), // tensor dimensions
make_tuple(N, 1), // strides
number<Vector_N>{}, // vector length for vectorized access
number<1>{} // guaranteed last dimension vector stride
);
```
### Tensor View
Combines memory buffer with tensor descriptor:
```cpp
auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, // memory buffer
make_tuple(M, N), // dimensions
make_tuple(N, 1), // strides
number<S::Vector_N>{}, // vector length
number<1>{} // guaranteed last dimension vector stride
);
```
### Tile Window
A view into a specific tile of the tensor with thread distribution:
```cpp
auto x_window = make_tile_window(
x_m_n, // tensor view
make_tuple(Block_Tile_M, Block_Tile_N), // tile size
{iM, 0}, // tile origin
tile_distribution // how work is distributed among threads
);
```
## The test_copy_basic Kernel
### Kernel Structure
The `TileCopyKernel` implements a basic copy operation from input tensor `x` to output tensor `y`:
```cpp
template <typename Problem_, typename Policy_>
struct TileCopyKernel
{
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
// 1. Create tensor views
// 2. Create tile windows
// 3. Iterate over N dimension tiles
// 4. Load, copy, and store data
}
};
```
### Step-by-Step Execution
1. **Tensor View Creation**:
```cpp
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
```
- Creates views for both input and output tensors
- Specifies vectorized access with `Vector_N` elements per load
2. **Tile Window Creation**:
```cpp
auto x_window = make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{iM, 0},
Policy::template MakeDRAMDistribution<Problem>());
```
- Creates windows into specific tiles of the tensors
- Each block processes one tile starting at `{iM, 0}`
- Tile distribution determines how threads access data
3. **N-Dimension Iteration**:
```cpp
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
```
- If tensor N dimension > Block_Tile_N, multiple iterations are needed
- Each iteration processes one tile along N dimension
4. **Load-Store Operations**:
```cpp
dram_reg_tile dram_tile;
load_tile(dram_tile, x_window); // Load from global memory to registers
store_tile(y_window, dram_tile); // Store from registers to global memory
move_tile_window(x_window, {0, S::Block_Tile_N}); // Move to next N tile
move_tile_window(y_window, {0, S::Block_Tile_N});
```
### How Load/Store Works
1. **Load Tile**:
- Each thread loads its assigned elements based on tile distribution
- Vectorized loads enable efficient memory bandwidth utilization
- Data is distributed to per-thread register buffers
2. **Store Tile**:
- Each thread writes its assigned elements back to global memory
- Maintains the same distribution pattern as load
3. **Tile Window Movement**:
- Moves the window to the next tile along N dimension
- Enables processing of large tensors that don't fit in one tile
## Memory Access Patterns
### Vectorized Access
- Enabled by specifying vector length in tensor views
- Each thread loads/stores multiple contiguous elements in one operation
- Improves memory bandwidth utilization
### Thread Distribution
- Tile distribution encoding determines which threads access which elements
- Ensures all threads participate and no data is missed
- Enables memory coalescing for optimal performance
### Coordinate Transform (Embed)
- Maps multi-dimensional tensor indices to linear memory addresses
- Handles stride calculations automatically
- Enables efficient access to non-contiguous memory layouts

View File

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include <cstring>
#include "copy_basic.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "128", "m dimension")
.insert("n", "8", "n dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision(fp16 or fp32)")
.insert("warmup", "50", "cold iter")
.insert("repeat", "100", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType;
using YDataType = DataType;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
// Create host tensors
ck_tile::HostTensor<XDataType> x_host({m, n}); // input matrix
ck_tile::HostTensor<YDataType> y_host_ref({m, n}); // reference output matrix
ck_tile::HostTensor<YDataType> y_host_dev({m, n}); // device output matrix
// Initialize input data with increasing values
ck_tile::half_t value = 1;
for(int i = 0; i < m; i++)
{
value = 1;
for(int j = 0; j < n; j++)
{
x_host(i, j) = value++;
}
}
// Allocate device memory
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
// Define tile configuration
using Vector = ck_tile::sequence<1, 4>; // vector size along M and N dimension
using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension
using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension
using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension
// Calculate grid size
ck_tile::index_t kGridSize =
ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{}));
std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl;
// Define kernel types
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Policy = ck_tile::TileCopyPolicy<Problem>;
using Kernel = ck_tile::ElementWiseTileCopyKernel<Problem, Policy>;
// using Kernel = ck_tile::TileCopyKernel<Problem, Policy>;
// using Kernel = ck_tile::TileCopyKernel_LDS<Problem, Policy>;
// question: Why do we not have a pipeline?
// answer: For basic copy operation, pipeline is not needed.
// we intentionally do not use pipeline for this example and let the kernel be composite of
// Problem and Policy
constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
// Print configuration information
std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
std::cout << "block tile (number of elements per block) " << BlockTile::at(ck_tile::number<0>{})
<< " " << BlockTile::at(ck_tile::number<1>{}) << std::endl;
std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{})
<< " " << WaveTile::at(ck_tile::number<1>{}) << std::endl;
std::cout << "vector (number of elements per thread) " << Vector::at(ck_tile::number<0>{})
<< " " << Vector::at(ck_tile::number<1>{}) << std::endl;
std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> ("
<< Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M
<< ")" << std::endl;
std::cout << "WaveRepetitionPerBlock_N = " << Shape::WaveRepetitionPerBlock_N << " --> ("
<< Shape::Block_Tile_N << "/" << Shape::Waves_Per_Block_N << "*" << Shape::Wave_Tile_N
<< ")" << std::endl;
// Launch kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<kBlockSize, 1>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
// Calculate and print performance metrics
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
// Copy results back to host
y_buf.FromDevice(y_host_dev.mData.data());
// Use exact equality (tolerance = 0) for copy operations since copy should be exact
pass = ck_tile::check_err(y_host_dev, x_host, "Error: Copy operation failed!", 0.0, 0.0);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
// Print results for debugging
// std::cout << "Input matrix (x_host):" << std::endl;
// std::cout << x_host << std::endl;
// std::cout << "Output matrix (y_host_dev):" << std::endl;
// std::cout << y_host_dev << std::endl;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
if(arg_parser.get_str("prec") == "fp16")
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
else
return run<float>(arg_parser) ? 0 : -2;
}

View File

@@ -0,0 +1,369 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
namespace ck_tile {
/**
* @brief Tile copy shape configuration
*
* @tparam BlockWaves Number of waves along seq<M, N>
* @tparam BlockTile Block size, seq<M, N>
* @tparam WaveTile Wave size, seq<M, N>
* @tparam Vector Contiguous elements (vector size) along seq<M, N>
*/
template <typename BlockWaves, typename BlockTile, typename WaveTile, typename Vector>
struct TileCopyShape
{
// Vector dimensions for memory operations
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t Vector_N = Vector::at(number<1>{});
// Wave tile dimensions
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
// Block tile dimensions
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
static constexpr index_t Block_Tile_N = BlockTile::at(number<1>{});
// Waves per block configuration
static constexpr index_t Waves_Per_Block_M = BlockWaves::at(number<0>{});
static constexpr index_t Waves_Per_Block_N = BlockWaves::at(number<1>{});
// Calculate wave repetition to cover entire block tile
static constexpr index_t WaveRepetitionPerBlock_M =
Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M);
static constexpr index_t WaveRepetitionPerBlock_N =
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
// Hardware configuration
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
// Configuration validation
static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive");
static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive");
static_assert(Vector_M > 0 && Vector_N > 0, "Vector dimensions must be positive");
static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0,
"Waves per block must be positive");
static_assert(Waves_Per_Block_M * Wave_Tile_M > 0,
"Invalid wave configuration for M dimension");
static_assert(Waves_Per_Block_N * Wave_Tile_N > 0,
"Invalid wave configuration for N dimension");
// Ensure wave tile dimensions align with wave size
static_assert(Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize,
"(Wave_Tile_M/Vector_M) * (Wave_Tile_N/Vector_N) != WaveSize");
};
/**
* @brief Problem definition for tile copy operation
*/
template <typename XDataType_, typename BlockShape_>
struct TileCopyProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
/**
* @brief Policy for tile copy operation
*/
template <typename Problem_>
struct TileCopyPolicy
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
/**
* @brief Create DRAM distribution for optimal memory access
*/
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
{
using S = typename Problem::BlockShape;
constexpr index_t wave_size = S::WaveSize;
constexpr index_t block_size = S::BlockSize;
// Distribution calculation to ensure all threads participate
constexpr index_t N1 = S::Vector_N; // Elements per thread along N
constexpr index_t N0 = S::Block_Tile_N / N1; // Threads needed along N
constexpr index_t M2 = wave_size / N0; // Threads per wave along M
constexpr index_t M1 = block_size / wave_size; // Waves possible along M
constexpr index_t M0 = S::Block_Tile_M / (M1 * M2); // Wave iterations along M
// Validate complete coverage
static_assert(M0 * M1 * M2 * N0 * N1 == S::Block_Tile_M * S::Block_Tile_N,
"Tile distribution must cover entire block tile");
constexpr auto outer_encoding =
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
return make_static_tile_distribution(outer_encoding);
}
};
/**
* @brief Direct copy kernel from global memory to global memory
*/
template <typename Problem_, typename Policy_>
struct TileCopyKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;
// Calculate tile block origin and validate bounds
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
// This saves VGPR usage by avoiding per-thread storage of the same value
const auto tile_block_origin_m =
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
if(tile_block_origin_m >= M)
{
return; // Early exit for out-of-bounds blocks
}
// Create tensor views for input and output
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
// Create tile windows with DRAM distribution
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
// This design choice is for simplicity and to avoid complex tile distribution.
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
// Get tile distribution for register tensor
auto DramTileDist = x_window.get_tile_distribution();
using dram_reg_tile = decltype(make_static_distributed_tensor<XDataType>(DramTileDist));
// Main copy loop - processes N dimension tiles sequentially within each block
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
dram_reg_tile dram_tile;
// Direct copy implementation
load_tile(dram_tile, x_window);
store_tile(y_window, dram_tile);
// Move to next N tile
move_tile_window(x_window, {0, S::Block_Tile_N});
move_tile_window(y_window, {0, S::Block_Tile_N});
}
}
};
/**
* @brief Element-wise copy kernel for data transformation scenarios
*
* This kernel performs element-wise copy operations, allowing for data transformation
* during the copy process. Useful when data needs to be processed or converted
* between different formats.
*/
template <typename Problem_, typename Policy_>
struct ElementWiseTileCopyKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;
// Calculate block origin and validate bounds
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
// This saves VGPR usage by avoiding per-thread storage of the same value
const auto tile_block_origin_m =
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
if(tile_block_origin_m >= M)
{
return; // Early exit for out-of-bounds blocks
}
// Create tensor views for input and output
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
// Create tile windows with DRAM distribution
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
// This design choice is for simplicity and to avoid complex tile distribution.
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
// Main element-wise copy loop - processes N dimension tiles sequentially within each block
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
// Element-wise copy implementation for data transformation
const auto xa = load_tile(x_window);
auto y_compute = load_tile(y_window);
constexpr auto spans = decltype(xa)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
const auto x = ck_tile::type_convert<XDataType>(xa[i_j_idx]);
y_compute(i_j_idx) = x;
});
});
store_tile(y_window, y_compute);
// Move to next N tile
move_tile_window(x_window, {0, S::Block_Tile_N});
move_tile_window(y_window, {0, S::Block_Tile_N});
}
}
};
/**
* @brief LDS-based copy kernel for data processing scenarios
*
* This kernel copies data from global memory to LDS and then to global memory,
* useful when data needs to be processed or transformed during the copy operation.
*/
template <typename Problem_, typename Policy_>
struct TileCopyKernel_LDS
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;
// Calculate block origin and validate bounds
// Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave
// This saves VGPR usage by avoiding per-thread storage of the same value
const auto tile_block_origin_m =
__builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M);
if(tile_block_origin_m >= M)
{
return; // Early exit for out-of-bounds blocks
}
// LDS buffer allocation
__shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N];
// LDS tensor descriptor and view
const auto x_lds_descriptor =
make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N),
make_tuple(S::Block_Tile_N, 1),
number<S::Vector_N>{},
number<1>{});
auto x_lds_view = make_tensor_view<address_space_enum::lds>(x_lds_buffer, x_lds_descriptor);
// LDS windows with different distributions for optimal access patterns
auto x_lds_write_window = make_tile_window(
x_lds_view, make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}), {0, 0});
auto x_lds_read_window =
make_tile_window(x_lds_view,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{0, 0},
Policy::template MakeDRAMDistribution<Problem>());
// Global memory tensor views
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
// Global memory tile windows
auto x_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0},
Policy::template MakeDRAMDistribution<Problem>());
auto y_window =
make_tile_window(y_m_n,
make_tuple(number<S::Block_Tile_M>{}, number<S::Block_Tile_N>{}),
{tile_block_origin_m, 0});
// Calculate iterations needed to cover N dimension
// Note: This kernel uses data parallelism only in the M dimension.
// Each block processes one tile in M dimension, but iterates through N dimension tiles.
// This design choice is for simplicity and to avoid complex tile distribution.
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N));
// Main copy loop with LDS staging - processes N dimension tiles sequentially within each
// block
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
// Global memory to LDS
auto dram_tile = load_tile(x_window);
store_tile(x_lds_write_window, dram_tile);
// Synchronize LDS access
block_sync_lds();
// LDS to global memory
auto lds_tile = load_tile(x_lds_read_window);
store_tile(y_window, lds_tile);
// Move to next N tile
move_tile_window(x_window, {0, S::Block_Tile_N});
move_tile_window(y_window, {0, S::Block_Tile_N});
}
}
};
} // namespace ck_tile

View File

@@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution)
add_subdirectory(21_elementwise)
add_subdirectory(35_batched_transpose)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -83,4 +83,14 @@ CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
template <typename T>
struct is_constant : std::false_type
{
};
template <auto v>
struct is_constant<constant<v>> : std::true_type
{
};
template <typename T>
inline constexpr bool is_constant_v = is_constant<T>::value;
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -259,6 +259,7 @@ struct tensor_adaptor
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
@@ -266,7 +267,9 @@ struct tensor_adaptor
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
static_for<0,
Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
1>{}([&](auto itran) {
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
@@ -298,11 +301,16 @@ struct tensor_adaptor
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
if constexpr(Internal > 0)
{
return make_tuple(vector_lengths, vector_strides);
}
else
{
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
}
private:

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -133,9 +133,10 @@ struct tensor_descriptor : public tensor_adaptor<Transforms,
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
template <index_t Internal = 0>
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
{
return Base::get_top_dimension_safe_vector_length_strides(
return Base::template get_top_dimension_safe_vector_length_strides<Internal>(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
@@ -377,12 +378,29 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
constexpr index_t first_dim_length = []() {
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
return decltype(element_space_size)::value;
else
return -1;
}();
using last_t = remove_cvref_t<decltype(lengths.template get<N - 1>())>;
constexpr index_t last_dim_length = []() {
if constexpr(is_constant_v<last_t>)
return std::max(last_t::value, GuaranteedLastDimensionVectorLength);
else
return -1;
}();
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
typename sequence_merge<sequence<first_dim_length>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<last_dim_length>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
typename sequence_merge<sequence<1>,
typename uniform_sequence_gen<N - 1, -1>::type,
sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,

View File

@@ -35,8 +35,6 @@ struct Add
return type_convert<T>(y_ + x_);
}
static constexpr bool requires_special_combine = false;
};
struct SquareAdd
@@ -64,28 +62,6 @@ struct SquareAdd
float x_ = type_convert<float>(x);
return type_convert<T>(y_ + (x_ * x_));
}
// For combining partial results
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
const T& partial2) const
{
return partial1 + partial2; // Just add the partial sums, don't square again
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const
{
float partial1_ = type_convert<float>(partial1);
float partial2_ = type_convert<float>(partial2);
return type_convert<T>(partial1_ + partial2_);
}
static constexpr bool requires_special_combine = true;
};
struct Max
@@ -109,8 +85,6 @@ struct Max
{
return max(y, x);
}
static constexpr bool requires_special_combine = false;
};
struct AbsMax
@@ -134,8 +108,6 @@ struct AbsMax
{
return max(y, abs(x));
}
static constexpr bool requires_special_combine = false;
};
} // namespace ReduceOp

View File

@@ -65,6 +65,11 @@ inline bool is_gfx12_supported()
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
}
inline bool is_load_tr_supported()
{
// Check if load transpose is supported.
return get_device_name() == "gfx950";
}
} // namespace ck_tile
#endif

View File

@@ -25,8 +25,10 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"

View File

@@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvialable = true;
#else
static constexpr bool kIsAvialable = !kUseTrLoad;
#endif
// clang-format off
template <typename T> struct t2s;
@@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
@@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <typename... Ts>
CK_TILE_HOST static constexpr Kargs
MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <typename... Ts>
CK_TILE_HOST static constexpr Kargs
MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
@@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
@@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
@@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(kIsAvialable)
run_(std::move(kargs));
}
CK_TILE_DEVICE void run_(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
@@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
if constexpr(kIsDeterministic)
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kIsDeterministic)
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
else
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
}();
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
else
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
const auto dq_acc_dram = pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}();
auto lse_dram_window =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this

View File

@@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -654,9 +656,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}();
// STAGE 3, P^T@OGrad^T Gemm1
Policy::template PTFromGemm0CToGemm1A<Problem,
decltype(pt_reg_tensor),
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
Policy::template PTFromGemm0CToGemm1A<Problem>(pt_reg_tensor, p_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
auto qt_reg_tensor = load_tile(qt_lds_read_window);
@@ -728,9 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
Policy::template SGradTFromGemm2CToGemm3A<Problem>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);

View File

@@ -6,22 +6,30 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
namespace ck_tile {
template <typename Problem>
template <typename Problem, typename Policy>
class BlockFmhaBwdDQDKDVPipelineSelector
{
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
public:
using type = std::conditional_t<has_dpad,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
template <typename... TS>
using type_ =
std::conditional_t<Problem::kUseTrLoad,
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>,
std::conditional_t<has_dpad,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
using type = std::conditional_t<std::is_same_v<Policy, void>, //
type_<Problem>,
type_<Problem, Policy>>;
};
template <typename Problem>
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
template <typename Problem, typename Policy = void>
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem, Policy>::type
{
public:
static constexpr const char* name = "auto";

View File

@@ -0,0 +1,760 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
// using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(kUseTrLoad, "This pipeline uses trload!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "trload_kr_ktr_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
return (raw_lse == -numeric<LSEDataType>::infinity()) //
? type_convert<LSEDataType>(0.f)
: raw_lse;
else
return raw_lse;
};
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_DEVICE auto operator()( //
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
// K, HBM ->LDS ->Reg
auto k_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
k_dram_block_window_tmp.get_bottom_tensor_view()),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
const auto k_origin = k_dram_window.get_window_origin();
// Early termination
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: here dk_acc&dv_acc are all cleard, return it
// Note: v loaded but no fence, ignore it.
return make_tuple(dk_acc, dv_acc);
}
}
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
v_dram_block_window_tmp.get_bottom_tensor_view()),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
//------------------------------------------------------------------
// KT, HBM -> LDS --trload-->Reg
async_load_tile(k_lds_write_window, k_dram_window);
async_load_tile(v_lds_write_window, v_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto k_lds_read = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
auto k_lds_read_window =
make_tile_window(k_lds_read,
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = load_tile(k_lds_read_window);
auto kt_lds_read_window =
make_tile_window(k_lds_read,
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
auto v_lds_read = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
auto v_lds_read_window =
make_tile_window(v_lds_read,
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegBlockDescriptor<Problem>());
auto v_reg_tensor = load_tile(v_lds_read_window);
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM -->LDS
auto q_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
q_dram_block_window_tmp.get_bottom_tensor_view()),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
auto q_lds_write_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read = make_tensor_view<address_space_enum::lds>(
q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
auto q_lds_read_window =
make_tile_window(q_lds_read,
make_tuple(number<kM0>{}, number<kK0>{}),
q_lds_write_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(q_lds_read,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->LDS ---load--> Reg
// dOT: \-loadtr-> Reg
auto do_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
do_dram_block_window_tmp.get_bottom_tensor_view()),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
auto do_lds_write_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read = make_tensor_view<address_space_enum::lds>(
do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
auto do_lds_read_window =
make_tile_window(do_lds_read,
make_tuple(number<kM0>{}, number<kK2>{}),
do_lds_write_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(do_lds_read,
make_tuple(number<kM0>{}, number<kK2>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
// dS: Reg -> Reg -> LDS
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// transform it to make it from col-major to row-major; prepared for load_tile_transpose
auto ds_lds_t = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
auto ds_lds_read_window =
make_tile_window(ds_lds_t,
make_tuple(number<kM0>{}, number<kK4>{}),
{0, 0},
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
bias_dram_block_window_tmp.get_bottom_tensor_view()),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>());
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
auto bias_lds_write_window =
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
auto bias_s_lds_read_window =
make_tile_window(bias_lds_read,
make_tuple(number<kM0>{}, number<kN0>{}),
bias_lds_write_window.get_window_origin(),
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbias_lds_read_window =
make_tile_window(bias_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
clear_tile(dv_acc);
clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
decltype(load_tile(q_lds_read_window)) q_reg_tensor;
decltype(load_tile(lse_lds_read_window)) lse;
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
decltype(load_tile(do_lds_read_window)) do_reg_tensor;
decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
decltype(load_tile(d_lds_read_window)) d;
decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
decltype(gemm_0.MakeCBlockTile()) s_acc, p;
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
decltype(gemm_4.MakeCBlockTile()) dq_acc;
decltype(load_tile(lse_dram_window)) lse_block_tile;
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
constexpr bool is_main_body = is_prologue && is_epilogue;
if constexpr(is_prologue)
{
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
async_load_tile(q_lds_write_window, q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
async_load_tile(do_lds_write_window, do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
}
if constexpr(is_epilogue)
{
// STAGE 1, Q@K Gemm0
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_epilogue)
{
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
async_load_tile(bias_lds_write_window, bias_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
s_acc,
bias_s_tile);
move_tile_window(bias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
else
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
});
});
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
}
const auto p_gemm = [&]() { // dropout / type conversion
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[](const auto& x) {
return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
},
p);
}
else
{
return cast_tile<GemmDataType>(p);
}
}();
// STAGE 4, OGrad@V Gemm2
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
// STAGE 3, P^T@OGrad^T Gemm1
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
}
block_sync_lds();
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_prologue)
{
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(d_lds_write_window, d_block_tile);
}
if constexpr(is_epilogue)
{
// STAGE 5, P^T(PGrad^T - D)
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbias = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
ds);
}
else
{
return cast_tile<BiasGradDataType>(ds);
}
}();
store_tile(bias_lds_write_window, dbias);
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
move_tile_window(dbias_dram_window, {kM0, 0});
__builtin_amdgcn_sched_barrier(0);
}
}
if constexpr(is_epilogue)
{
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
if constexpr(is_prologue)
{
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
}
if constexpr(is_epilogue)
{
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_epilogue)
{
// STAGE7 SGrad@K^T Gemm4
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {kK4, 0});
}
auto kt_reg_tensor_slice = get_slice_tile( //
kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
move_tile_window(ds_lds_read_window, {-kN0, 0});
}
block_sync_lds();
if constexpr(is_prologue)
{
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
if constexpr(is_epilogue)
{
// QGrad Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
{
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
}
move_tile_window(dq_dram_window, {kM0, 0});
}
i_total_bodys += 1;
};
main_body(std::true_type{}, std::false_type{});
// Hot loop
if(num_total_loop > 1)
{
do
{
main_body(std::true_type{}, std::true_type{});
i_total_loops += 1;
seqlen_q_step += kM0;
} while(i_total_loops < num_total_loop - 1);
}
main_body(std::false_type{}, std::true_type{});
// Results Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
return make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile

View File

@@ -64,7 +64,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
CK_TILE_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::GemmDataType,
@@ -77,13 +77,19 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -143,13 +149,19 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,

View File

@@ -27,6 +27,7 @@ template <typename QDataType_,
bool kIsDeterministic_,
typename FmhaMask_,
typename FmhaDropout_,
bool kUseTrLoad_,
typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
@@ -53,6 +54,7 @@ struct BlockFmhaBwdPipelineProblem
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// attributes from traits
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,7 +11,9 @@ namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
bool TransposeC_ = false>
struct BlockGemmARegBRegCRegV1
{
private:
@@ -44,8 +46,9 @@ struct BlockGemmARegBRegCRegV1
};
public:
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr bool TransposeC = TransposeC_;
using Traits = GemmTraits_<Problem, Policy>;
@@ -131,6 +134,7 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
@@ -138,7 +142,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
@@ -152,7 +156,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
@@ -172,25 +176,19 @@ struct BlockGemmARegBRegCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
@@ -219,7 +217,6 @@ struct BlockGemmARegBRegCRegV1
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
@@ -227,16 +224,16 @@ struct BlockGemmARegBRegCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
@@ -244,7 +241,7 @@ struct BlockGemmARegBRegCRegV1
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
@@ -254,6 +251,7 @@ struct BlockGemmARegBRegCRegV1
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
if constexpr(UseDefaultScheduler)
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
@@ -261,7 +259,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -277,7 +275,7 @@ struct BlockGemmARegBRegCRegV1
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
c_distr_ys_major,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(

View File

@@ -183,16 +183,7 @@ struct BlockReduce2dSync
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
v_local = reduce_func(v_local, v_remote);
});
}
});
@@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
constexpr auto i_1 = number<i_1_n1 + 1>{};
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
// For reduce, use combine_partial_results for operations that require it
if constexpr(ReduceFunc::requires_special_combine)
{
v_local = reduce_func.combine_partial_results(v_local, v_remote);
}
else
{
v_local = reduce_func(v_local, v_remote);
}
v_local = reduce_func(v_local, v_remote);
});
y_tensor.get_thread_buffer()(i_0) = v_local;

View File

@@ -189,7 +189,9 @@ struct Reduce
/// @note Requirements:
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
/// - input_strides[-1] == 1 (for contiguous memory access)
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
template <typename InputStrides>
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
InputStrides input_strides)
{
using S = typename Problem::BlockShape;

View File

@@ -308,20 +308,8 @@ using TestConfig_F32_Max = std::tuple<float,
Shape1_WarpTile,
Shape1_ThreadTile>;
using TestConfig_F32_SquareAdd = std::tuple<float,
float,
float,
ck_tile::ReduceOp::SquareAdd,
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile,
Shape1_ThreadTile>;
using TestTypes = ::testing::Types<TestConfig_F32_Add,
TestConfig_F16_Add,
TestConfig_F32_CrossWarp,
TestConfig_F32_Max,
TestConfig_F32_SquareAdd>;
using TestTypes = ::testing::
Types<TestConfig_F32_Add, TestConfig_F16_Add, TestConfig_F32_CrossWarp, TestConfig_F32_Max>;
TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);