diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a21634b7d..9c942a776d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added +* Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). diff --git a/Dockerfile.aiter b/Dockerfile.aiter index f6e66f460a..245e39fb75 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -1,16 +1,20 @@ ARG BASE_DOCKER="rocm/pytorch:latest" FROM $BASE_DOCKER -RUN groupadd -f render && \ +ARG AITER_BRANCH="main" +ARG CK_AITER_BRANCH="develop" +RUN groupadd -g 109 render && \ + usermod -u 1001 jenkins && \ + groupmod -g 1001 jenkins && \ pip install pandas zmq einops && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ cd /home/jenkins/workspace && \ rm -rf aiter && \ - git clone --recursive https://github.com/ROCm/aiter.git && \ + git clone -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \ cd aiter && \ rm -rf 3rdparty/composable_kernel/ && \ - git clone https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \ + git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \ python3 setup.py develop && \ chown -R jenkins:jenkins /home/jenkins/workspace && \ chmod -R a+rwx /home/jenkins/workspace && \ diff --git a/Jenkinsfile b/Jenkinsfile index a2b7d2f4b7..ed8b9c9d46 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -190,7 +190,7 @@ def buildDocker(install_prefix){ } else if(params.RUN_AITER_TESTS){ image_name = "rocm/composable_kernel:ck_aiter" - dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter . " + dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " } else{ dockerArgs = dockerArgs + " -f Dockerfile . " @@ -438,34 +438,6 @@ def cmake_build(Map conf=[:]){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." } } - if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ - try{ - archiveArtifacts "perf_transpose_*.log" - if (arch_type == 1){ - stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a" - } - else if (arch_type == 2){ - stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942" - } - } - catch(Exception err){ - echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." - } - } - if (params.RUN_CK_TILE_GEMM_TESTS){ - try{ - archiveArtifacts "perf_tile_gemm_**.log" - if (arch == 1){ - stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" - } - else if (arch == 2){ - stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" - } - } - catch(Exception err){ - echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." - } - } } def buildHipClangJob(Map conf=[:]){ @@ -762,24 +734,6 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } - if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ - try{ - unstash "perf_transpose_log_gfx942" - unstash "perf_transpose_log_gfx90a" - } - catch(Exception err){ - echo "could not locate the Transpose performance logs: ${err.getMessage()}." - } - } - if (params.RUN_CK_TILE_GEMM_TESTS){ - try{ - unstash "perf_tile_gemm_log_gfx942" - unstash "perf_tile_gemm_log_gfx90a" - } - catch(Exception err){ - echo "could not locate the GEMM performance logs: ${err.getMessage()}." - } - } if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ // unstash deb packages unstash "packages" @@ -843,10 +797,11 @@ def run_aiter_tests(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts) { timeout(time: 45, unit: 'MINUTES'){ try{ - sh "python3 --version" sh "rocminfo" - sh "python3 ../aiter/op_tests/test_gemm_a8w8_blockscale.py" - //sh "python3 ../aiter/op_tests/test_mha.py" + sh "python3 --version" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" } catch(e){ echo "Throwing error exception while running AITER tests" @@ -861,7 +816,7 @@ def run_aiter_tests(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true @@ -941,14 +896,6 @@ pipeline { name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, description: "Run the ck_tile FMHA tests (default: OFF)") - booleanParam( - name: "RUN_CK_TILE_TRANSPOSE_TESTS", - defaultValue: false, - description: "Run the ck_tile Transpose tests (default: OFF)") - booleanParam( - name: "RUN_CK_TILE_GEMM_TESTS", - defaultValue: false, - description: "Run the ck_tile GEMM tests (default: OFF)") booleanParam( name: "RUN_TILE_ENGINE_GEMM_TESTS", defaultValue: false, @@ -1009,6 +956,14 @@ pipeline { name: "RUN_AITER_TESTS", defaultValue: false, description: "Run AITER tests with latest CK develop branch (default: OFF)") + string( + name: 'aiter_branch', + defaultValue: 'main', + description: 'Specify which branch of AITER to use (default: main)') + string( + name: 'ck_aiter_branch', + defaultValue: 'develop', + description: 'Specify which branch of CK to test with AITER (default: develop)') } environment{ dbuser = "${dbuser}" @@ -1093,13 +1048,13 @@ pipeline { { parallel { - stage("Run AITER Tests on gfx90a") + stage("Run AITER Tests on gfx942") { when { beforeAgent true expression { params.RUN_AITER_TESTS.toBoolean() } } - agent{ label rocmnode("gfx90a")} + agent{ label rocmnode("gfx942")} steps{ run_aiter_tests() cleanWs() @@ -1198,94 +1153,6 @@ pipeline { } } } - stage("Run CK_TILE_TRANSPOSE Tests") - { - parallel - { - stage("Run CK_TILE_TRANSPOSE Tests on gfx90a") - { - when { - beforeAgent true - expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx90a") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 tile_example_batched_transpose && \ - cd ../ && - example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } - stage("Run CK_TILE_TRANSPOSE Tests on gfx942") - { - when { - beforeAgent true - expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx942") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_batched_transpose && \ - cd ../ && - example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } - } - } - stage("Run CK_TILE_GEMM Tests") - { - parallel - { - stage("Run CK_TILE_GEMM Tests on gfx90a") - { - when { - beforeAgent true - expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx90a") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 tile_example_gemm_universal && \ - cd ../ && - example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } - stage("Run CK_TILE_GEMM Tests on gfx942") - { - when { - beforeAgent true - expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx942") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_gemm_universal && \ - cd ../ && - example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } - } - } stage("Run TILE_ENGINE_GEMM Tests") { parallel diff --git a/README.md b/README.md index 29d3d4e85a..459e17d9a3 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa 4. Build the entire CK library: ```bash - make -j + make -j"$(nproc)" ``` 5. Install CK: @@ -213,4 +213,4 @@ script/uninstall_precommit.sh ``` If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the -`git commit` command. \ No newline at end of file +`git commit` command. diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 47cf6b3ad4..8ca917cb6c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< {F_deterministic}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, + {F_trload}, fmha_bwd_trait_{F_idx}>; using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline; @@ -113,7 +114,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dbias}, {F_dpad}, {F_dvpad}, - {F_deterministic}>; + {F_deterministic}, + {F_trload}>; #include @@ -168,29 +170,35 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) template <> float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; }} """ -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} +FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_body} }} """ -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_body} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_body} + }} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; - r = fmha_bwd_(s, a); - return r; - }} +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>; + r = fmha_bwd_(s, a); + return r; + }} """ # M0 size for 1d kernels (dot/convert) @@ -250,6 +258,7 @@ class FmhaBwdDQDKDVKernel: F_mode : str # value from MODE_MAP F_deterministic : str # mask_impl : str # + F_trload : str # @property def template(self) -> str: @@ -291,6 +300,7 @@ class FmhaBwdDQDKDVKernel: F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], + F_trload = BOOL_MAP[self.F_trload], ) @property @@ -324,6 +334,9 @@ class FmhaBwdDQDKDVKernel: if self.F_deterministic == 't' : n += '_deterministic' else: n += '_ndeterministic' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @property @@ -332,8 +345,8 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. -def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]: + if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return { '32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), '64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -341,6 +354,10 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), } + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': + return { + '128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + } else: return None @@ -573,6 +590,7 @@ class FmhaBwdApiTrait: dvpad : str deterministic : str mask_impl : str + tr_load : bool @property def bm0(self) -> int: @@ -620,7 +638,7 @@ class FmhaBwdApiTrait: def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, - F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl) + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) @property def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: @@ -636,12 +654,13 @@ class FmhaBwdApiTrait: class FmhaBwdApiPool: def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list)) + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + self.mask_impl = mask_impl def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: # TODO: do we need to check duplication? - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait)) @staticmethod def if_(i: int) -> str: @@ -656,24 +675,31 @@ class FmhaBwdApiPool: F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) + F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load]) i += 1 return inners @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]): - traits=self.dq_dk_dv_pool[dtype][hdim] - inners = self._api_innders(traits) - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners) - per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + per_tr_load = '' + for tr_load in ["t", "f"]: + per_dtypes = '' + for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]): + per_hdim_case = '' + for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]): + traits = self.dq_dk_dv_pool[tr_load][dtype][hdim] + inners = self._api_innders(traits) + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners) + per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case) + per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: if filter_list == '': @@ -690,8 +716,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} api_pool = FmhaBwdApiPool(mask_impl) - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load) if d is None: continue for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): @@ -703,7 +729,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + if tr_load == "t" and (dpad == "t" or dvpad == "t"): + continue # tr_load cannot work with dpad or dvpad + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): continue diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index c999cf750e..bd63c96eb1 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" @@ -363,7 +364,8 @@ template + bool kIsDeterministic_, + bool kUseTrLoad_> struct fmha_bwd_dq_dk_dv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -376,6 +378,7 @@ struct fmha_bwd_dq_dk_dv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kUseTrLoad = kUseTrLoad_; }; template diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index e9403f4698..c0e4dc3d30 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -526,9 +526,9 @@ bool run(const ck_tile::ArgParser& arg_parser) static_cast(2) * mask.get_unmaskarea() * hdim_v); num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k + sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k); } } diff --git a/example/ck_tile/39_copy/CMakeLists.txt b/example/ck_tile/39_copy/CMakeLists.txt new file mode 100644 index 0000000000..98397a33d2 --- /dev/null +++ b/example/ck_tile/39_copy/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_example_copy EXCLUDE_FROM_ALL copy_basic.cpp) + +# Impact: This flag ensures that the compiler doesn't make +# assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns. +target_compile_options(tile_example_copy PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) diff --git a/example/ck_tile/39_copy/README.md b/example/ck_tile/39_copy/README.md new file mode 100644 index 0000000000..f45fcb682b --- /dev/null +++ b/example/ck_tile/39_copy/README.md @@ -0,0 +1,313 @@ +# CK Tile Framework: Getting Started with Tile Copy Operations + +## Overview + +### Copy Kernel +A minimal CK_Tile memory copy implementation demonstrating the basic setup required to write a kernel in CK Tile. +This experimental kernel is intended for novice CK developers. It introduces the building blocks of CK Tile and provides a sandbox for experimenting with kernel parameters. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture +# (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the copy kernel executable +make tile_example_copy -j +``` +This will result in an executable `build/bin/test_copy_basic` + +## example +``` +args: + -m input matrix rows. (default 64) + -n input matrix cols. (default 8) + -id wave to use for computation. (default 0) + -v validation flag to check device results. (default 1) + -prec datatype precision to use. (default fp16) + -warmup no. of warmup iterations. (default 50) + -repeat no. of iterations for kernel execution time. (default 100) +``` + +## CK Tile Architecture Components + +The CK Tile framework is built around four key architectural components that work together to define and execute GPU kernels: shape, policy, problem, and pipeline. + +### **1. Shape** +Defines the **hierarchical tile structure** and **memory layout** of the kernel: + +```cpp +using Shape = ck_tile::TileCopyShape; +``` + +**Components:** +- **BlockWaves**: Number of concurrent waves per block (e.g., `seq<4, 1>` for 4 waves along M, 1 along N) +- **BlockTile**: Total elements processed by one block (e.g., `seq<512, 8>`) +- **WaveTile**: Elements processed by one wave (e.g., `seq<32, 8>`) +- **Vector**: Elements processed by one thread (e.g., `seq<1, 4>` for 4 contiguous elements) + +**Purpose**: Defines the **work distribution hierarchy** from threads → waves → blocks. + +### **2. Problem** +Defines the **data types** and **kernel configuration**: + +```cpp +using Problem = ck_tile::TileCopyProblem; +``` + +**Components:** +- **XDataType**: Input/output data type (e.g., `float`, `half`) +- **Shape**: The tile shape defined above + +**Purpose**: Encapsulates **what** the kernel operates on and **how** it's configured. + +### **3. Policy** +Defines the **memory access patterns** and **distribution strategies**: + +```cpp +using Policy = ck_tile::TileCopyPolicy; +``` + +**Key Functions:** +- **MakeDRAMDistribution()**: Defines how threads access DRAM memory. + +**Purpose**: Defines **how** data is accessed and distributed across threads. + +### **4. Pipeline** +Defines the **execution flow** and **memory movement patterns**: + +```cpp +// Example pipeline stages: +// 1. DRAM → Registers (load_tile) +// 2. Registers → LDS (store_tile) +// 3. LDS → Registers (load_tile with distribution) +// 4. Registers → DRAM (store_tile) +``` + +**Purpose**: Defines the **sequence of operations** and **memory movement strategy**. + +### **Component Interaction** + +```cpp +// Complete kernel definition +using Shape = ck_tile::TileCopyShape; +using Problem = ck_tile::TileCopyProblem; +using Policy = ck_tile::TileCopyPolicy; +using Kernel = ck_tile::TileCopyKernel; +``` + +**Flow:** +1. **Shape** defines the tile structure and work distribution +2. **Problem** combines data types with the shape +3. **Policy** defines memory access patterns for the problem +4. **Kernel** implements the actual computation using all components + +### **Why This Architecture?** + +#### **Separation of Concerns** +- **Shape**: Focuses on **work distribution** and **tile structure** +- **Problem**: Focuses on **data types** and **configuration** +- **Policy**: Focuses on **memory access** and **optimization** +- **Pipeline**: Focuses on **execution flow** and **synchronization** + +#### **Reusability** +- Same **Shape** can be used with different **Problems** +- Same **Policy** can be applied to different **Shapes** +- **Pipelines** can be reused across different kernels + +#### **Performance Optimization** +- **Shape** enables optimal work distribution +- **Policy** enables optimal memory access patterns +- **Pipeline** enables optimal execution flow + +## Core Concepts + +### Hierarchical Tile Structure + +The CK Tile framework organizes work in a hierarchical manner: + +1. **Vector**: Number of contiguous elements processed by a single thread + - Enables vectorized memory loads/stores. + - Example: `Vector = seq<1, 4>` means each thread loads 4 contiguous elements along the N dimension + - A Vector can be imagined as a thread-level tile + +2. **WaveTile**: Number of elements covered by a single wave (64 threads on AMD) + - Must satisfy: `Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize` + - This ensures the number of threads needed equals the wave size + - Example: `WaveTile = seq<64, 4>` with `Vector = seq<1, 4>` means: + - Each thread handles 4 elements (Vector_N = 4) + - Wave needs 64×4/4 = 64 threads to cover 64×4 = 256 elements + - Total elements = 256, which requires WaveSize = 64 threads + +3. **BlockTile**: Number of elements covered by one block (typically mapped to one CU) + - Example: `BlockTile = seq<256, 64>` means each block processes 256×64 elements + +4. **BlockWaves**: Number of concurrent waves active in a block + - Usually 4 waves per block on modern AMD GPUs + - Example: `BlockWaves = seq<4, 1>` means 4 waves along M dimension, 1 along N + +### Wave Repetition + +In many scenarios, the total work (BlockTile) is larger than what the available waves can cover in a single iteration. This requires **wave repetition**: + +```cpp +// Calculate how many times a wave needs to repeat to cover the entire block tile +static constexpr index_t WaveRepetitionPerBlock_M = + Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M); +static constexpr index_t WaveRepetitionPerBlock_N = + Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N); +``` + +**Key Insight**: When waves repeat, the effective work per thread becomes `Vector * Repeat`, not just `Vector`. + +## Tile Distribution Encoding + +The tile distribution encoding specifies how work is distributed across threads: + +```cpp +constexpr auto outer_encoding = + tile_distribution_encoding, // replication + tuple, sequence>, // hierarchy + tuple, sequence<1, 2>>, // parallelism + tuple, sequence<2, 0>>, // paralleism + sequence<1, 2>, // yield + sequence<0, 1>>{}; // yield +``` + +### Encoding Parameters Explained + +- **M0, M1, M2**: Hierarchical distribution along M dimension + - M0: Number of wave iterations along M + - M1: Number of waves along M + - M2: Number of threads per wave along M +- **N0, N1**: Distribution along N dimension + - N0: Number of threads along N + - N1: Vector size (elements per thread) +- **YIELD arguments**: Both `Repeat` and `Vector` because effective work per thread is `Vector * Repeat` + +## Tensor Abstractions + +### Tensor Descriptor +Defines the logical structure of a tensor: +```cpp +auto desc = make_naive_tensor_descriptor( + make_tuple(M, N), // tensor dimensions + make_tuple(N, 1), // strides + number{}, // vector length for vectorized access + number<1>{} // guaranteed last dimension vector stride +); +``` + +### Tensor View +Combines memory buffer with tensor descriptor: +```cpp +auto x_m_n = make_naive_tensor_view( + p_x, // memory buffer + make_tuple(M, N), // dimensions + make_tuple(N, 1), // strides + number{}, // vector length + number<1>{} // guaranteed last dimension vector stride +); +``` + +### Tile Window +A view into a specific tile of the tensor with thread distribution: +```cpp +auto x_window = make_tile_window( + x_m_n, // tensor view + make_tuple(Block_Tile_M, Block_Tile_N), // tile size + {iM, 0}, // tile origin + tile_distribution // how work is distributed among threads +); +``` + +## The test_copy_basic Kernel + +### Kernel Structure + +The `TileCopyKernel` implements a basic copy operation from input tensor `x` to output tensor `y`: + +```cpp +template +struct TileCopyKernel +{ + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + // 1. Create tensor views + // 2. Create tile windows + // 3. Iterate over N dimension tiles + // 4. Load, copy, and store data + } +}; +``` + +### Step-by-Step Execution + +1. **Tensor View Creation**: + ```cpp + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + ``` + - Creates views for both input and output tensors + - Specifies vectorized access with `Vector_N` elements per load + +2. **Tile Window Creation**: + ```cpp + auto x_window = make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeDRAMDistribution()); + ``` + - Creates windows into specific tiles of the tensors + - Each block processes one tile starting at `{iM, 0}` + - Tile distribution determines how threads access data + +3. **N-Dimension Iteration**: + ```cpp + index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + ``` + - If tensor N dimension > Block_Tile_N, multiple iterations are needed + - Each iteration processes one tile along N dimension + +4. **Load-Store Operations**: + ```cpp + dram_reg_tile dram_tile; + load_tile(dram_tile, x_window); // Load from global memory to registers + store_tile(y_window, dram_tile); // Store from registers to global memory + move_tile_window(x_window, {0, S::Block_Tile_N}); // Move to next N tile + move_tile_window(y_window, {0, S::Block_Tile_N}); + ``` + +### How Load/Store Works + +1. **Load Tile**: + - Each thread loads its assigned elements based on tile distribution + - Vectorized loads enable efficient memory bandwidth utilization + - Data is distributed to per-thread register buffers + +2. **Store Tile**: + - Each thread writes its assigned elements back to global memory + - Maintains the same distribution pattern as load + +3. **Tile Window Movement**: + - Moves the window to the next tile along N dimension + - Enables processing of large tensors that don't fit in one tile + +## Memory Access Patterns + +### Vectorized Access +- Enabled by specifying vector length in tensor views +- Each thread loads/stores multiple contiguous elements in one operation +- Improves memory bandwidth utilization + +### Thread Distribution +- Tile distribution encoding determines which threads access which elements +- Ensures all threads participate and no data is missed +- Enables memory coalescing for optimal performance + +### Coordinate Transform (Embed) +- Maps multi-dimensional tensor indices to linear memory addresses +- Handles stride calculations automatically +- Enables efficient access to non-contiguous memory layouts diff --git a/example/ck_tile/39_copy/copy_basic.cpp b/example/ck_tile/39_copy/copy_basic.cpp new file mode 100644 index 0000000000..d46add879c --- /dev/null +++ b/example/ck_tile/39_copy/copy_basic.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include +#include "copy_basic.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "128", "m dimension") + .insert("n", "8", "n dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision(fp16 or fp32)") + .insert("warmup", "50", "cold iter") + .insert("repeat", "100", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Create host tensors + ck_tile::HostTensor x_host({m, n}); // input matrix + ck_tile::HostTensor y_host_ref({m, n}); // reference output matrix + ck_tile::HostTensor y_host_dev({m, n}); // device output matrix + + // Initialize input data with increasing values + ck_tile::half_t value = 1; + for(int i = 0; i < m; i++) + { + value = 1; + for(int j = 0; j < n; j++) + { + x_host(i, j) = value++; + } + } + + // Allocate device memory + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + // Define tile configuration + using Vector = ck_tile::sequence<1, 4>; // vector size along M and N dimension + using WaveTile = ck_tile::sequence<64, 4>; // wave size along M and N dimension + using BlockWaves = ck_tile::sequence<4, 1>; // number of waves along M dimension + using BlockTile = ck_tile::sequence<512, 4>; // block size along M and N dimension + + // Calculate grid size + ck_tile::index_t kGridSize = + ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); + std::cout << "grid size (number of blocks per grid) " << kGridSize << std::endl; + + // Define kernel types + using Shape = ck_tile::TileCopyShape; + using Problem = ck_tile::TileCopyProblem; + using Policy = ck_tile::TileCopyPolicy; + using Kernel = ck_tile::ElementWiseTileCopyKernel; + // using Kernel = ck_tile::TileCopyKernel; + // using Kernel = ck_tile::TileCopyKernel_LDS; + + // question: Why do we not have a pipeline? + // answer: For basic copy operation, pipeline is not needed. + // we intentionally do not use pipeline for this example and let the kernel be composite of + // Problem and Policy + + constexpr ck_tile::index_t kBlockSize = Shape::BlockSize; + + // Print configuration information + std::cout << "block size (number of threads per block) " << kBlockSize << std::endl; + std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl; + std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{}) + << " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl; + std::cout << "block tile (number of elements per block) " << BlockTile::at(ck_tile::number<0>{}) + << " " << BlockTile::at(ck_tile::number<1>{}) << std::endl; + std::cout << "wave tile (number of elements per wave) " << WaveTile::at(ck_tile::number<0>{}) + << " " << WaveTile::at(ck_tile::number<1>{}) << std::endl; + std::cout << "vector (number of elements per thread) " << Vector::at(ck_tile::number<0>{}) + << " " << Vector::at(ck_tile::number<1>{}) << std::endl; + std::cout << "WaveRepetitionPerBlock_M = " << Shape::WaveRepetitionPerBlock_M << " --> (" + << Shape::Block_Tile_M << "/" << Shape::Waves_Per_Block_M << "*" << Shape::Wave_Tile_M + << ")" << std::endl; + std::cout << "WaveRepetitionPerBlock_N = " << Shape::WaveRepetitionPerBlock_N << " --> (" + << Shape::Block_Tile_N << "/" << Shape::Waves_Per_Block_N << "*" << Shape::Wave_Tile_N + << ")" << std::endl; + + // Launch kernel + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, warmup, repeat, 1}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n)); + + // Calculate and print performance metrics + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // Copy results back to host + y_buf.FromDevice(y_host_dev.mData.data()); + // Use exact equality (tolerance = 0) for copy operations since copy should be exact + pass = ck_tile::check_err(y_host_dev, x_host, "Error: Copy operation failed!", 0.0, 0.0); + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + // Print results for debugging + // std::cout << "Input matrix (x_host):" << std::endl; + // std::cout << x_host << std::endl; + // std::cout << "Output matrix (y_host_dev):" << std::endl; + // std::cout << y_host_dev << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + if(arg_parser.get_str("prec") == "fp16") + return run(arg_parser) ? 0 : -2; + else + return run(arg_parser) ? 0 : -2; +} diff --git a/example/ck_tile/39_copy/copy_basic.hpp b/example/ck_tile/39_copy/copy_basic.hpp new file mode 100644 index 0000000000..bbeb964fda --- /dev/null +++ b/example/ck_tile/39_copy/copy_basic.hpp @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +/** + * @brief Tile copy shape configuration + * + * @tparam BlockWaves Number of waves along seq + * @tparam BlockTile Block size, seq + * @tparam WaveTile Wave size, seq + * @tparam Vector Contiguous elements (vector size) along seq + */ +template +struct TileCopyShape +{ + // Vector dimensions for memory operations + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + // Wave tile dimensions + static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{}); + static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{}); + + // Block tile dimensions + static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_Tile_N = BlockTile::at(number<1>{}); + + // Waves per block configuration + static constexpr index_t Waves_Per_Block_M = BlockWaves::at(number<0>{}); + static constexpr index_t Waves_Per_Block_N = BlockWaves::at(number<1>{}); + + // Calculate wave repetition to cover entire block tile + static constexpr index_t WaveRepetitionPerBlock_M = + Block_Tile_M / (Waves_Per_Block_M * Wave_Tile_M); + static constexpr index_t WaveRepetitionPerBlock_N = + Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N); + + // Hardware configuration + static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize; + + // Configuration validation + static_assert(Block_Tile_M > 0 && Block_Tile_N > 0, "Block tile dimensions must be positive"); + static_assert(Wave_Tile_M > 0 && Wave_Tile_N > 0, "Wave tile dimensions must be positive"); + static_assert(Vector_M > 0 && Vector_N > 0, "Vector dimensions must be positive"); + static_assert(Waves_Per_Block_M > 0 && Waves_Per_Block_N > 0, + "Waves per block must be positive"); + static_assert(Waves_Per_Block_M * Wave_Tile_M > 0, + "Invalid wave configuration for M dimension"); + static_assert(Waves_Per_Block_N * Wave_Tile_N > 0, + "Invalid wave configuration for N dimension"); + + // Ensure wave tile dimensions align with wave size + static_assert(Wave_Tile_M / Vector_M * Wave_Tile_N / Vector_N == WaveSize, + "(Wave_Tile_M/Vector_M) * (Wave_Tile_N/Vector_N) != WaveSize"); +}; + +/** + * @brief Problem definition for tile copy operation + */ +template +struct TileCopyProblem +{ + using XDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +/** + * @brief Policy for tile copy operation + */ +template +struct TileCopyPolicy +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + + /** + * @brief Create DRAM distribution for optimal memory access + */ + template + CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t wave_size = S::WaveSize; + constexpr index_t block_size = S::BlockSize; + + // Distribution calculation to ensure all threads participate + constexpr index_t N1 = S::Vector_N; // Elements per thread along N + constexpr index_t N0 = S::Block_Tile_N / N1; // Threads needed along N + + constexpr index_t M2 = wave_size / N0; // Threads per wave along M + constexpr index_t M1 = block_size / wave_size; // Waves possible along M + constexpr index_t M0 = S::Block_Tile_M / (M1 * M2); // Wave iterations along M + + // Validate complete coverage + static_assert(M0 * M1 * M2 * N0 * N1 == S::Block_Tile_M * S::Block_Tile_N, + "Tile distribution must cover entire block tile"); + + constexpr auto outer_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}; + return make_static_tile_distribution(outer_encoding); + } +}; + +/** + * @brief Direct copy kernel from global memory to global memory + */ +template +struct TileCopyKernel +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + using Policy = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + // Calculate tile block origin and validate bounds + // Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave + // This saves VGPR usage by avoiding per-thread storage of the same value + const auto tile_block_origin_m = + __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M); + if(tile_block_origin_m >= M) + { + return; // Early exit for out-of-bounds blocks + } + + // Create tensor views for input and output + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + // Create tile windows with DRAM distribution + auto x_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + auto y_window = + make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + // Calculate iterations needed to cover N dimension + // Note: This kernel uses data parallelism only in the M dimension. + // Each block processes one tile in M dimension, but iterates through N dimension tiles. + // This design choice is for simplicity and to avoid complex tile distribution. + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + + // Get tile distribution for register tensor + auto DramTileDist = x_window.get_tile_distribution(); + using dram_reg_tile = decltype(make_static_distributed_tensor(DramTileDist)); + + // Main copy loop - processes N dimension tiles sequentially within each block + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + dram_reg_tile dram_tile; + + // Direct copy implementation + load_tile(dram_tile, x_window); + store_tile(y_window, dram_tile); + + // Move to next N tile + move_tile_window(x_window, {0, S::Block_Tile_N}); + move_tile_window(y_window, {0, S::Block_Tile_N}); + } + } +}; + +/** + * @brief Element-wise copy kernel for data transformation scenarios + * + * This kernel performs element-wise copy operations, allowing for data transformation + * during the copy process. Useful when data needs to be processed or converted + * between different formats. + */ +template +struct ElementWiseTileCopyKernel +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + using Policy = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + // Calculate block origin and validate bounds + // Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave + // This saves VGPR usage by avoiding per-thread storage of the same value + const auto tile_block_origin_m = + __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M); + if(tile_block_origin_m >= M) + { + return; // Early exit for out-of-bounds blocks + } + + // Create tensor views for input and output + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + // Create tile windows with DRAM distribution + auto x_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + auto y_window = + make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + // Calculate iterations needed to cover N dimension + // Note: This kernel uses data parallelism only in the M dimension. + // Each block processes one tile in M dimension, but iterates through N dimension tiles. + // This design choice is for simplicity and to avoid complex tile distribution. + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + + // Main element-wise copy loop - processes N dimension tiles sequentially within each block + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + // Element-wise copy implementation for data transformation + const auto xa = load_tile(x_window); + auto y_compute = load_tile(y_window); + + constexpr auto spans = decltype(xa)::get_distributed_spans(); + + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + const auto x = ck_tile::type_convert(xa[i_j_idx]); + y_compute(i_j_idx) = x; + }); + }); + + store_tile(y_window, y_compute); + + // Move to next N tile + move_tile_window(x_window, {0, S::Block_Tile_N}); + move_tile_window(y_window, {0, S::Block_Tile_N}); + } + } +}; + +/** + * @brief LDS-based copy kernel for data processing scenarios + * + * This kernel copies data from global memory to LDS and then to global memory, + * useful when data needs to be processed or transformed during the copy operation. + */ +template +struct TileCopyKernel_LDS +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + using Policy = ck_tile::remove_cvref_t; + + CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + // Calculate block origin and validate bounds + // Use __builtin_amdgcn_readfirstlane to broadcast the same value to all threads in a wave + // This saves VGPR usage by avoiding per-thread storage of the same value + const auto tile_block_origin_m = + __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_Tile_M); + if(tile_block_origin_m >= M) + { + return; // Early exit for out-of-bounds blocks + } + + // LDS buffer allocation + __shared__ XDataType x_lds_buffer[S::Block_Tile_M * S::Block_Tile_N]; + + // LDS tensor descriptor and view + const auto x_lds_descriptor = + make_naive_tensor_descriptor(make_tuple(S::Block_Tile_M, S::Block_Tile_N), + make_tuple(S::Block_Tile_N, 1), + number{}, + number<1>{}); + + auto x_lds_view = make_tensor_view(x_lds_buffer, x_lds_descriptor); + + // LDS windows with different distributions for optimal access patterns + auto x_lds_write_window = make_tile_window( + x_lds_view, make_tuple(number{}, number{}), {0, 0}); + + auto x_lds_read_window = + make_tile_window(x_lds_view, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeDRAMDistribution()); + + // Global memory tensor views + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + // Global memory tile windows + auto x_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}, + Policy::template MakeDRAMDistribution()); + + auto y_window = + make_tile_window(y_m_n, + make_tuple(number{}, number{}), + {tile_block_origin_m, 0}); + + // Calculate iterations needed to cover N dimension + // Note: This kernel uses data parallelism only in the M dimension. + // Each block processes one tile in M dimension, but iterates through N dimension tiles. + // This design choice is for simplicity and to avoid complex tile distribution. + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_Tile_N)); + + // Main copy loop with LDS staging - processes N dimension tiles sequentially within each + // block + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + // Global memory to LDS + auto dram_tile = load_tile(x_window); + store_tile(x_lds_write_window, dram_tile); + + // Synchronize LDS access + block_sync_lds(); + + // LDS to global memory + auto lds_tile = load_tile(x_lds_read_window); + store_tile(y_window, lds_tile); + + // Move to next N tile + move_tile_window(x_window, {0, S::Block_Tile_N}); + move_tile_window(y_window, {0, S::Block_Tile_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 630b96ede0..8fce70ba04 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) +add_subdirectory(39_copy) diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 2ba2fd10c6..1eec80828a 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -83,4 +83,14 @@ CK_TILE_BINARY_OP(<=) #undef CK_TILE_LEFT_UNARY_OP #undef CK_TILE_BINARY_OP +template +struct is_constant : std::false_type +{ +}; +template +struct is_constant> : std::true_type +{ +}; +template +inline constexpr bool is_constant_v = is_constant::value; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index ec5538d79c..eb226debfd 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -259,6 +259,7 @@ struct tensor_adaptor CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + template CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides( const array& guaranteed_vector_lengths, const array& guaranteed_vector_strides) @@ -266,7 +267,9 @@ struct tensor_adaptor auto vector_lengths = guaranteed_vector_lengths; auto vector_strides = guaranteed_vector_strides; - static_for<0, get_num_of_transform(), 1>{}([&](auto itran) { + static_for<0, + Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(), + 1>{}([&](auto itran) { constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran); constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran); @@ -298,11 +301,16 @@ struct tensor_adaptor set_container_subset(vector_lengths, up_dims, up_vector_lengths); set_container_subset(vector_strides, up_dims, up_vector_strides); }); - - constexpr auto top_dims = TopDimensionHiddenIds{}; - - return make_tuple(get_container_subset(vector_lengths, top_dims), - get_container_subset(vector_strides, top_dims)); + if constexpr(Internal > 0) + { + return make_tuple(vector_lengths, vector_strides); + } + else + { + constexpr auto top_dims = TopDimensionHiddenIds{}; + return make_tuple(get_container_subset(vector_lengths, top_dims), + get_container_subset(vector_strides, top_dims)); + } } private: diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 0e4787a2f1..3b372d45dd 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -133,9 +133,10 @@ struct tensor_descriptor : public tensor_adaptor CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides() { - return Base::get_top_dimension_safe_vector_length_strides( + return Base::template get_top_dimension_safe_vector_length_strides( to_array(GuaranteedVectorLengths{}), to_array(GuaranteedVectorStrides{})); } @@ -377,12 +378,29 @@ make_naive_tensor_descriptor_packed(const tuple& lengths, const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); + constexpr index_t first_dim_length = []() { + if constexpr(is_constant_v>) + return decltype(element_space_size)::value; + else + return -1; + }(); + using last_t = remove_cvref_t())>; + constexpr index_t last_dim_length = []() { + if constexpr(is_constant_v) + return std::max(last_t::value, GuaranteedLastDimensionVectorLength); + else + return -1; + }(); + using GuaranteedVectorLengths = - typename sequence_merge::type, - sequence>::type; + typename sequence_merge, + typename uniform_sequence_gen::type, + sequence>::type; using GuaranteedVectorStrides = - typename sequence_merge::type, sequence<1>>::type; + typename sequence_merge, + typename uniform_sequence_gen::type, + sequence<1>>::type; return tensor_descriptor, remove_cv_t, diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp index 2d7ac78b06..a698c91e45 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -35,8 +35,6 @@ struct Add return type_convert(y_ + x_); } - - static constexpr bool requires_special_combine = false; }; struct SquareAdd @@ -64,28 +62,6 @@ struct SquareAdd float x_ = type_convert(x); return type_convert(y_ + (x_ * x_)); } - - // For combining partial results - template || std::is_same_v || - std::is_same_v || std::is_same_v>> - CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1, - const T& partial2) const - { - return partial1 + partial2; // Just add the partial sums, don't square again - } - - template || std::is_same_v || - std::is_same_v || std::is_same_v>> - CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const - { - float partial1_ = type_convert(partial1); - float partial2_ = type_convert(partial2); - return type_convert(partial1_ + partial2_); - } - - static constexpr bool requires_special_combine = true; }; struct Max @@ -109,8 +85,6 @@ struct Max { return max(y, x); } - - static constexpr bool requires_special_combine = false; }; struct AbsMax @@ -134,8 +108,6 @@ struct AbsMax { return max(y, abs(x)); } - - static constexpr bool requires_special_combine = false; }; } // namespace ReduceOp diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index 502eb38e12..f86e4b889a 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -65,6 +65,11 @@ inline bool is_gfx12_supported() return get_device_name() == "gfx1200" || get_device_name() == "gfx1201"; } +inline bool is_load_tr_supported() +{ + // Check if load transpose is supported. + return get_device_name() == "gfx950"; +} } // namespace ck_tile #endif diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 313de5f29a..276ec4852f 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -25,8 +25,10 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 8b184b18f3..595e2cfccf 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kHasDropout = FmhaDropout::IsDropout; static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval; static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic; + static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad; +#if defined(__gfx950__) + static constexpr bool kIsAvialable = true; +#else + static constexpr bool kIsAvialable = !kUseTrLoad; +#endif // clang-format off template struct t2s; @@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + - (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ); + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel using Kargs = std::conditional_t; + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr Kargs + MakeKargs(Ts... args, const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr Kargs + MakeKargs(Ts... args, const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_do, - ck_tile::index_t batch_stride_lsed, - ck_tile::index_t batch_stride_dq_acc, - ck_tile::index_t batch_stride_dk, - ck_tile::index_t batch_stride_dv, - ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, - batch_stride_dk, - batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_do, - ck_tile::index_t batch_stride_lsed, - ck_tile::index_t batch_stride_dq_acc, - ck_tile::index_t batch_stride_dk, - ck_tile::index_t batch_stride_dv, - ck_tile::index_t batch_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, - batch_stride_dk, - batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - const void* lse_ptr, - const void* do_ptr, - const void* d_ptr, - void* rand_val_ptr, - void* dk_ptr, - void* dv_ptr, - void* dbias_ptr, - void* dq_acc_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_do, - ck_tile::index_t stride_dq_acc, - ck_tile::index_t stride_dk, - ck_tile::index_t stride_dv, - ck_tile::index_t stride_dbias, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_do, - ck_tile::index_t nhead_stride_lsed, - ck_tile::index_t nhead_stride_dq_acc, - ck_tile::index_t nhead_stride_dk, - ck_tile::index_t nhead_stride_dv, - ck_tile::index_t nhead_stride_dbias, - ck_tile::index_t split_stride_dq_acc, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - bias_ptr, - lse_ptr, - do_ptr, - d_ptr, - rand_val_ptr, - dk_ptr, - dv_ptr, - dbias_ptr, - dq_acc_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_do, - stride_dq_acc, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - split_stride_dq_acc, - window_size_left, - window_size_right, - mask_type, - p_drop, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) { @@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel } CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if constexpr(kIsAvialable) + run_(std::move(kargs)); + } + + CK_TILE_DEVICE void run_(Kargs kargs) const { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; @@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel {0, 0}); auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { - if constexpr(kIsDeterministic) - { - AccDataType* dq_acc_ptr = - reinterpret_cast(kargs.dq_acc_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - static_cast(i_tile_n_) * kargs.split_stride_dq_acc + - batch_offset_dq_acc; + AccDataType* dq_acc_ptr = reinterpret_cast(kargs.dq_acc_ptr) + [&]() { + if constexpr(kIsDeterministic) + return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + + static_cast(i_tile_n_) * kargs.split_stride_dq_acc + + batch_offset_dq_acc; + else + return static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + + batch_offset_dq_acc; + }(); - auto dq_acc_dram = [&]() { - const auto dq_acc_dram_naive = - make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); - } - else - { - AccDataType* dq_acc_ptr = - reinterpret_cast(kargs.dq_acc_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_dq_acc + - batch_offset_dq_acc; - - auto dq_acc_dram = [&]() { - const auto dq_acc_dram_naive = - make_naive_tensor_view( - dq_acc_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_dq_acc, 1), - number{}, - number<1>{}); - - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); - } + constexpr auto DstInMemOp = conditional_expr( + memory_operation_enum::set, memory_operation_enum::atomic_add); + const auto dq_acc_dram_naive = + make_naive_tensor_view( + dq_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_dq_acc, 1), + number{}, + number<1>{}); + const auto dq_acc_dram = pad_tensor_view( + dq_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + dq_acc_dram, + make_tuple(number{}, number{}), + {0, 0}); }(); auto lse_dram_window = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 1f11569533..d36f8ad724 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(!kUseTrLoad, "This pipeline does not use trload!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 967fe2362d..88fb1281aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(!kUseTrLoad, "This pipeline does not use trload!"); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -654,9 +656,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }(); // STAGE 3, P^T@OGrad^T Gemm1 - Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); + Policy::template PTFromGemm0CToGemm1A(pt_reg_tensor, p_gemm); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); auto qt_reg_tensor = load_tile(qt_lds_read_window); @@ -728,9 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // STAGE 6, SGrad^T@Q^T Gemm3 const auto ds_gemm = cast_tile(ds); - Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp index 80c311de86..bf38c3c07d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -6,22 +6,30 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp" namespace ck_tile { -template +template class BlockFmhaBwdDQDKDVPipelineSelector { static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; public: - using type = std::conditional_t, - BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>; + template + using type_ = + std::conditional_t, + std::conditional_t, + BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>>; + using type = std::conditional_t, // + type_, + type_>; }; -template -class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector::type +template +class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector::type { public: static constexpr const char* name = "auto"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp new file mode 100644 index 0000000000..1d95bc2801 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using FmhaDropout = remove_cvref_t; + // using HotLoopScheduler = typename Policy::template HotLoopScheduler; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kIsDeterministic = Problem::kIsDeterministic; + static constexpr bool kUseTrLoad = Problem::kUseTrLoad; + static_assert(kUseTrLoad, "This pipeline uses trload!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = 1; + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = 1; + + static constexpr const char* name = "trload_kr_ktr_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking) + return (raw_lse == -numeric::infinity()) // + ? type_convert(0.f) + : raw_lse; + else + return raw_lse; + }; + + template + CK_TILE_DEVICE auto operator()( // + const QDramBlockWindowTmp& q_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, + float scale, + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + FmhaDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + // K, HBM ->LDS ->Reg + auto k_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + k_dram_block_window_tmp.get_bottom_tensor_view()), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + const auto k_origin = k_dram_window.get_window_origin(); + + // Early termination + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return make_tuple(dk_acc, dv_acc); + } + } + + // LDS allocation + const auto smem_ptr_ = + reinterpret_cast(smem_ptr); // cast to char* to do pointer arithmetic + + const auto k_lds_ptr = reinterpret_cast(smem_ptr_); + const auto v_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeK()); + + const auto do_lds_ptr0 = reinterpret_cast(smem_ptr_); + const auto do_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr0 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad()); + const auto q_lds_ptr1 = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ()); + const auto lse_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ()); + const auto d_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE()); + const auto ds_lds_ptr = reinterpret_cast( + smem_ptr_ + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeOGrad() + + Policy::template GetSmemSizeQ() + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeLSE() + Policy::template GetSmemSizeD()); + const auto bias_lds_ptr = reinterpret_cast(ds_lds_ptr); + + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); + auto k_lds_write_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // V, HBM ->LDS ->Reg + auto v_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + v_dram_block_window_tmp.get_bottom_tensor_view()), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVDramTileDistribution()); + auto v_lds = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); + auto v_lds_write_window = + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + + //------------------------------------------------------------------ + // KT, HBM -> LDS --trload-->Reg + async_load_tile(k_lds_write_window, k_dram_window); + async_load_tile(v_lds_write_window, v_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + //------------------------------------------------------------------ + // Pre-Load KV into Registers + auto k_lds_read = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor()); + auto k_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + k_lds_write_window.get_window_origin(), + Policy::template MakeKRegBlockDescriptor()); + auto k_reg_tensor = load_tile(k_lds_read_window); + + auto kt_lds_read_window = + make_tile_window(k_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeKTRegBlockDescriptor()); + + auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + + auto v_lds_read = make_tensor_view( + v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor()); + auto v_lds_read_window = + make_tile_window(v_lds_read, + make_tuple(number{}, number{}), + v_lds_write_window.get_window_origin(), + Policy::template MakeVRegBlockDescriptor()); + auto v_reg_tensor = load_tile(v_lds_read_window); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + //---------------------------- Loop Load in ----------------------------// + // Q: HBM -->LDS + auto q_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + q_dram_block_window_tmp.get_bottom_tensor_view()), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeQDramTileDistribution()); + + auto q_lds = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor()); + auto q_lds_write_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + auto q_lds_read = make_tensor_view( + q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor()); + auto q_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + q_lds_write_window.get_window_origin(), + Policy::template MakeQRegSliceBlockDescriptor()); + auto qt_lds_read_window = + make_tile_window(q_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeQTRegSliceBlockDescriptor()); + + // dO: HBM ->LDS ---load--> Reg + // dOT: \-loadtr-> Reg + auto do_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + do_dram_block_window_tmp.get_bottom_tensor_view()), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}, + Policy::template MakeOGradDramTileDistribution()); + + auto do_lds = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor()); + auto do_lds_write_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + auto do_lds_read = make_tensor_view( + do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor()); + auto do_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + do_lds_write_window.get_window_origin(), + Policy::template MakeOGradRegSliceBlockDescriptor()); + auto dot_lds_read_window = + make_tile_window(do_lds_read, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeOGradTRegSliceBlockDescriptor()); + + // dS: Reg -> Reg -> LDS + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // transform it to make it from col-major to row-major; prepared for load_tile_transpose + auto ds_lds_t = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_read_window = + make_tile_window(ds_lds_t, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeSGradRegSliceBlockDescriptor()); + + // Bias: HBM ->Reg ->Reg ->LDS + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + auto bias_dram_window = + make_tile_window(Policy::template TransformXDramTensorView( + bias_dram_block_window_tmp.get_bottom_tensor_view()), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}, + Policy::template MakeBiasTileDistribution()); + + auto bias_lds = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor()); + auto bias_lds_write_window = + make_tile_window(bias_lds, make_tuple(number{}, number{}), {0, 0}); + + auto bias_lds_read = make_tensor_view( + bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor()); + auto bias_s_lds_read_window = + make_tile_window(bias_lds_read, + make_tuple(number{}, number{}), + bias_lds_write_window.get_window_origin(), + Policy::template MakeBiasSTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // LSE: HBM -> LDS ->Reg + auto lse_dram_window = make_tile_window( + lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + auto lse_lds = make_tensor_view( + lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + + auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number{}), {0}); + + auto lse_lds_read_window = make_tile_window( + lse_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // D: HBM ->Reg + auto d_dram_window = make_tile_window( + d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}, + Policy::template MakeLSEDDramTileDistribution()); + + auto d_lds = make_tensor_view( + d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor()); + auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number{}), {0}); + auto d_lds_read_window = make_tile_window( + d_lds, + make_tuple(number{}), + {0}, + Policy::template MakeLSEDLdsReadBlockDescriptor()); + + // RandVal: HBM ->Reg + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + // BiasGrad + // Reg ->LDS ->Reg ->HBM + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + + auto dbias_dram_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto dbias_lds_read_window = + make_tile_window(bias_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + // ----------------------------Loop write out------------------------------// + auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + index_t i_total_loops = 0; + index_t seqlen_q_step = seqlen_q_start; + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); + static_assert(kM0 == kK1, "kM0 should equal to kK1"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); + static_assert(kM0 == kK3, "kM0 should equal to kK3"); + constexpr index_t k4_loops = kN0 / kK4; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + __builtin_amdgcn_sched_barrier(0); + + decltype(load_tile(q_lds_read_window)) q_reg_tensor; + decltype(load_tile(lse_lds_read_window)) lse; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor; + decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next; + decltype(load_tile(do_lds_read_window)) do_reg_tensor; + decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor; + decltype(load_tile(d_lds_read_window)) d; + decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor; + decltype(gemm_0.MakeCBlockTile()) s_acc, p; + decltype(gemm_2.MakeCBlockTile()) dp_acc, ds; + decltype(gemm_4.MakeCBlockTile()) dq_acc; + + decltype(load_tile(lse_dram_window)) lse_block_tile; + decltype(load_tile(d_dram_window)) d_block_tile; + + index_t i_total_bodys = 0; + auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable { + const bool is_even = (i_total_bodys % 2 == 0); + QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0; + QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1; + OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0; + OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1; + + constexpr bool is_prologue = is_prologue_.value; + constexpr bool is_epilogue = is_epilogue_.value; + static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true"); + constexpr bool is_main_body = is_prologue && is_epilogue; + + if constexpr(is_prologue) + { + q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + async_load_tile(q_lds_write_window, q_dram_window); + move_tile_window(q_dram_window, {kM0, 0}); + + lse_block_tile = load_tile(lse_dram_window); + move_tile_window(lse_dram_window, {kM0}); + + do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + async_load_tile(do_lds_write_window, do_dram_window); + move_tile_window(do_dram_window, {kM0, 0}); + + d_block_tile = load_tile(d_dram_window); + move_tile_window(d_dram_window, {kM0}); + } + if constexpr(is_epilogue) + { + // STAGE 1, Q@K Gemm0 + s_acc = gemm_0(q_reg_tensor, k_reg_tensor); + + dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); + dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm0(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + async_load_tile(bias_lds_write_window, bias_dram_window); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto bias_s_tile = load_tile(bias_s_lds_read_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x = scale * x + log2e_v * type_convert(y); + }, + s_acc, + bias_s_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + + { + bool need_perpixel_check = mask.IsEdgeTile( + seqlen_q_step, k_origin.at(number<0>{}), number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = seqlen_q_step + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + constexpr auto p_spans = decltype(p)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse); + else + p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse); + }); + }); + + if constexpr(FmhaDropout::IsDropout) + { + dropout.template Run( + seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window); + } + const auto p_gemm = [&]() { // dropout / type conversion + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [](const auto& x) { + return type_convert(x > 0.f ? x : 0.f); + }, + p); + } + else + { + return cast_tile(p); + } + }(); + + // STAGE 4, OGrad@V Gemm2 + dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); + + qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); + qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + + // STAGE 3, P^T@OGrad^T Gemm1 + auto pt_reg_tensor = make_static_distributed_tensor( + Policy::template MakePTRegSliceBlockDescriptor()); + pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); + } + block_sync_lds(); + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm12(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_prologue) + { + store_tile(lse_lds_write_window, lse_block_tile); + store_tile(d_lds_write_window, d_block_tile); + } + if constexpr(is_epilogue) + { + // STAGE 5, P^T(PGrad^T - D) + constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); + sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = p[i_j_idx] >= 0; + ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag + ? (dp_acc[i_j_idx] - d[i_idx]) + : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbias = [&]() { + if constexpr(FmhaDropout::IsDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + ds); + } + else + { + return cast_tile(ds); + } + }(); + store_tile(bias_lds_write_window, dbias); + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); + auto dbias_tile = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbias_tile, shuffled_dbias_tile); + store_tile(dbias_dram_window, dbias_tile); + move_tile_window(dbias_dram_window, {kM0, 0}); + __builtin_amdgcn_sched_barrier(0); + } + } + if constexpr(is_epilogue) + { + // STAGE 6, SGrad^T@Q^T Gemm3 + const auto ds_gemm = cast_tile(ds); + auto dst_reg_tensor = make_static_distributed_tensor( + Policy::template MakeSGradTRegSliceBlockDescriptor()); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); + gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + store_tile(ds_lds_window, ds_gemm); + } + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + if constexpr(is_prologue) + { + q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next); + q_reg_tensor = load_tile(q_lds_read_window); + lse = load_tile(lse_lds_read_window); + } + if constexpr(is_epilogue) + { + ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm3(); + __builtin_amdgcn_sched_barrier(0); + if constexpr(is_epilogue) + { + // STAGE7 SGrad@K^T Gemm4 + clear_tile(dq_acc); + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + move_tile_window(ds_lds_read_window, {kK4, 0}); + } + auto kt_reg_tensor_slice = get_slice_tile( // + kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); + gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); + + if constexpr(i_k4 < k4_loops - 1) + { + ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer(); + } + }); + move_tile_window(ds_lds_read_window, {-kN0, 0}); + } + block_sync_lds(); + if constexpr(is_prologue) + { + do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next); + do_reg_tensor = load_tile(do_lds_read_window); + d = load_tile(d_lds_read_window); + } + if constexpr(is_main_body) + Policy::template HotLoopScheduler::SchedulerGemm4(); + if constexpr(is_epilogue) + { + // QGrad Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + if constexpr(kIsDeterministic) + { + store_tile(dq_dram_window, dq_acc); + } + else + { + update_tile(dq_dram_window, dq_acc); + } + move_tile_window(dq_dram_window, {kM0, 0}); + } + i_total_bodys += 1; + }; + + main_body(std::true_type{}, std::false_type{}); + // Hot loop + if(num_total_loop > 1) + { + do + { + main_body(std::true_type{}, std::true_type{}); + i_total_loops += 1; + seqlen_q_step += kM0; + } while(i_total_loops < num_total_loop - 1); + } + main_body(std::false_type{}, std::true_type{}); + + // Results Scale + if constexpr(FmhaDropout::IsDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + + return make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index c86d2525bc..68ead7c765 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -64,7 +64,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + CK_TILE_DEVICE static constexpr auto GetPTOGradTBlockGemm() { using GemmProblem = BlockGemmProblem>; - using WarpGemm = WarpGemmDispatcher{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true>; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy>; - using WarpGemm = WarpGemmDispatcher{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), - true>; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true, + false, // SwizzleAccess + false, // UseStructuredSparsity + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) + ? WGAttrNumAccessEnum ::Double + : WGAttrNumAccessEnum ::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy struct BlockFmhaBwdPipelineProblem { @@ -53,6 +54,7 @@ struct BlockFmhaBwdPipelineProblem static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsDeterministic = kIsDeterministic_; + static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp new file mode 100644 index 0000000000..6cef1db730 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -0,0 +1,1220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +#include "ck_tile/core/utility/debug.hpp" + +namespace ck_tile { + +struct BlockFmhaBwdPipelineTrLoadDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto SwizzleA = false; + using WarpGemm = WarpGemmMfmaDispatcher< // + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), + false, + SwizzleA>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + return BlockFmhaBwdPipelineDefaultPolicy::GetPTOGradTBlockGemm(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>>; + + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::OGradDataType, + typename Problem::VDataType, + typename Problem::AccDataType, + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}), + false, + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + return BlockFmhaBwdPipelineDefaultPolicy::GetSGradTQTBlockGemm(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockFmhaShape = typename Problem::BlockFmhaShape; + using GemmProblem = BlockGemmProblem< + typename Problem::GemmDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + Problem::kBlockSize, + TileGemmShape< + sequence, + typename BlockFmhaShape::Gemm4BlockWarps, + typename BlockFmhaShape::Gemm4WarpTile>>; + + using WarpGemm = WarpGemmMfmaDispatcher{}), + BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + false, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentX() noexcept + { + return 16 / sizeof(T); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + return GetAlignmentX(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + return GetAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + return GetAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + return GetAlignmentX(); + } + + // these are for load_tr_b64 + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentX() noexcept + { + return 8 / sizeof(T); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() noexcept + { + return GetTransposedAlignmentX(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + return total_pixels / GetAlignmentOGrad(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + return total_pixels / GetAlignmentBias(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc() + { + using AccDataType = remove_cvref_t; + return 16 / sizeof(AccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad() + { + return GetAlignmentPostQGradAcc(); + } + + // It is found that alignment of 8x dwordx4 can avoid bank conflicts for both transposed and + // non-transposed load + static constexpr index_t WarpAlignmentBytes = 128; + + // As load_lds requires contiguous LDS write, we need to transform the distribution of DRAM for + // reading + template + CK_TILE_HOST_DEVICE static constexpr auto TransformXDramTensorView(const TensorView& naive_view) + { + if constexpr(std::is_same_v) + { + return naive_view; + } + else + { + const auto transformed_desc = + TransformXDramDescriptor(naive_view.get_tensor_descriptor()); + return tensor_view, + TensorView::DstInMemOp>{naive_view.buf_, transformed_desc}; + } + } + template + CK_TILE_HOST_DEVICE static constexpr auto + TransformXDramDescriptor(const tensor_descriptor& from_desc) + { + using from_desc_t = tensor_descriptor; + + constexpr auto ndims = from_desc_t::get_num_of_dimension(); + static_assert(ndims == 2, "XDram descriptor must have 2 dimensions"); + const auto Rows = from_desc.get_length(number<0>{}); + // constexpr auto Cols = 128; + // assert(from_desc.get_length(number<1>{}) == 128); + const auto Cols = from_desc.get_length(number<1>{}); + + constexpr index_t Dwordx4Bytes = 16; + constexpr index_t K2 = Dwordx4Bytes / sizeof(T); + constexpr index_t K1 = WarpAlignmentBytes / Dwordx4Bytes; + const index_t K0 = Cols / K1; + const auto ColLens = make_tuple(K0, number{}, number{}); + + const auto desc_tmp1 = transform_tensor_descriptor( + from_desc, + make_tuple(make_pass_through_transform(Rows), make_unmerge_transform(ColLens)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{})); + + const auto desc_tmp2 = transform_tensor_descriptor( + desc_tmp1, + make_tuple(make_xor_transform(make_tuple(Rows, number{})), + make_pass_through_transform(K0), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + return transform_tensor_descriptor( + desc_tmp2, + make_tuple(make_pass_through_transform(Rows), + make_merge_transform_v3_division_mod(ColLens)), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kWarps = kBlockSize / get_warp_size(); + + constexpr index_t K2 = GetAlignmentK(); + constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2; + constexpr index_t K0 = ColsPerBlock / K1 / K2; + static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes, + "ColsPerBlock notdivisible"); + + constexpr index_t N2 = get_warp_size() / K1; + constexpr index_t N1 = kWarps / K0; + constexpr index_t N0 = RowsPerBlock / N1 / N2; + static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) && + (K1 * N2 == get_warp_size()), + "RowsPerBlock not divisible"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, // K0 N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<1, 2>, // N0 K2 + sequence<0, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + return MakeXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution() + { + using AccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; + + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; + + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence, sequence>, + tuple, sequence<2, 3>>, + tuple, sequence<2, 0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution() + { + using AccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kKPerBlock = Problem::kQKHeaddim; + + constexpr index_t K1 = 16 / sizeof(AccDataType); + constexpr index_t K0 = kKPerBlock / K1; + + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeKRegBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() + { + return BlockFmhaBwdPipelineDefaultPolicy::MakeVRegBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto kt_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, // 2 4, 4 + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + auto output = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(kt_block_dstr_encode), + typename Problem::KDataType>::TransposedDstrEncode{}); + return output; + } + + // lds write descriptor used together with block_sync_lds (transformed dram descriptor) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsWriteBlockDescriptor() + { + constexpr index_t KPack = WarpAlignmentBytes / sizeof(T); + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{})); + return transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor() + { + return MakeXLdsWriteBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + // SGrad should be of the same distr as Gemm2 OGradV's output (i.e. PGrad) + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t M2 = WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane; + constexpr index_t M1 = WarpGemm::WarpGemmAttribute::Impl::kCMLane; + static_assert(WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane == 1, "kCM0PerLane must be 1"); + constexpr index_t M0 = kMPerBlock / (M1 * M2); + + constexpr index_t N1 = WarpGemm::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = kNPerBlock / N1; + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{}, number{}, number{})); + + constexpr index_t M1_0 = 2, M1_1 = 2; + constexpr index_t N1_0 = 2, N1_1 = 8; + static_assert(M1_0 * M1_1 == M1, "M1_0 * M1_1 must equal M1"); + static_assert(N1_0 * N1_1 == N1, "N1_0 * N1_1 must equal N1"); + + constexpr auto desc_1 = transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4, 5>{}, sequence<6>{})); + constexpr auto desc_2 = transform_tensor_descriptor( + desc_1, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 4>{}, + sequence<3>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 4>{}, + sequence<3>{}, + sequence<5>{}, + sequence<6>{})); + + constexpr auto top_dims = []() { + if constexpr(Transposed) + return make_tuple(sequence<1>{}, sequence<0>{}); + else + return make_tuple(sequence<0>{}, sequence<1>{}); + }(); + return transform_tensor_descriptor( + desc_2, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 2, 3, 6>{}, sequence<1, 4, 5>{}), + top_dims); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsReadBlockDescriptor() + { + const auto Dwordx4Bytes = 16; + const auto K2 = Dwordx4Bytes / sizeof(T); + const auto K1 = WarpAlignmentBytes / Dwordx4Bytes; + const auto K0 = KPerBlock / (K1 * K2); + + constexpr auto desc_0 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number{}, number{}, number{})); + constexpr auto desc_1 = transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + return transform_tensor_descriptor( + desc_1, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor() + { + return MakeXLdsReadBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto q_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto qt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(qt_block_dstr_encode), + typename Problem::QDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dst_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode); + + return dst_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + using LSEDType = remove_cvref_t; + constexpr index_t kMPack = 16 / sizeof(LSEDType); + + constexpr auto lsed_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}), + make_tuple(number<1>{}), + number{}, + number<1>{}); + + return lsed_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + // M4 *2 and M2 /2 when swizzle mode enabled + constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2; + // constexpr index_t SwizzleConfig = 1; + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto do_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode); + + return do_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + // constexpr index_t kNPerBlock = 32; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto dot_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + // CK_PRINT(); + // CK_PRINT(); + + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(dot_block_dstr_encode), + typename Problem::OGradDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto pt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode); + + return pt_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4; + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto ds_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return make_static_tile_distribution( + typename InputTileDistributionTraits< + decltype(ds_block_dstr_encode), + typename Problem::GemmDataType>::TransposedDstrEncode{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = GetTransposedAlignmentBias(); + constexpr index_t M1 = get_warp_size() / N0; + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<1, 2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ() + { + return sizeof(typename Problem::QDataType) * + MakeQLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK() + { + return sizeof(typename Problem::KDataType) * + MakeKLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD() + { + return sizeof(typename Problem::DDataType) * + MakeLSEDLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV() + { + return sizeof(typename Problem::VDataType) * + MakeVLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad() + { + return sizeof(typename Problem::OGradDataType) * + MakeOGradLdsWriteBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad() + { + return sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias() + { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasLdsWriteBlockDescriptor().get_element_space_size(); + else + return 0; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_lse = GetSmemSizeLSE(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_d = GetSmemSizeD(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + + constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v; + constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse + + smem_size_d + max(smem_size_bias, smem_size_ds); + return max(smem_size_stage0, smem_size_stage1); + } + + template + class HotLoopScheduler + { + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; + static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; + static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; + static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; + + static constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static constexpr index_t WarpGemmN = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}); + static constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + static constexpr index_t Gemm4MWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + static constexpr index_t Gemm4NWarp = + Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + static constexpr index_t blockWarps = kBlockSize / get_warp_size(); + using GemmDataType = typename Problem::GemmDataType; + + // Compute + static constexpr index_t Gemm0MFMA = + kM0 * kN0 * kK0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm1MFMA = + kN0 * kVHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kM0 * kN0 * kK2 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm3MFMA = + kN0 * kQKHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm4MFMA = + kM0 * kQKHeaddim * kN0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK); + + // VMEM + static constexpr index_t Q_VMEM_READ = + kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t OGrad_VMEM_READ = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t LSE_VMEM_READ = 1; + static constexpr index_t D_VMEM_READ = 1; + + // LDS Read + static constexpr index_t OGradT_LDS_READ = + kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + static constexpr index_t QT_LDS_READ = + kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + static constexpr index_t SGradT_LDS_READ_P1 = + kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX(); + static constexpr index_t SGradT_LDS_READ_P2 = + kM0 * kN0 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX() - + SGradT_LDS_READ_P1; + static constexpr index_t Q_LDS_READ = + kM0 * kK0 / get_warp_size() / GetAlignmentQ(); + static constexpr index_t LSE_LDS_READ = kM0 / (4 * 4); + static constexpr index_t D_LDS_READ = LSE_LDS_READ; + static constexpr index_t OGrad_LDS_READ = + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); + + // LDS Write + static constexpr index_t Q_LDS_WRITE = + kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ(); + static constexpr index_t QT_LDS_WRITE = + kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ(); + static constexpr index_t OGrad_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + static constexpr index_t OGradT_LDS_WRITE = + kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad(); + static constexpr index_t LSE_LDS_WRITE = 1; + static constexpr index_t D_LDS_WRITE = 1; + static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize; + + public: + CK_TILE_DEVICE static constexpr void SchedulerGemm0() + { + // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load + // Comp: Q x K + constexpr index_t VMEM_READ_INST = + Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ; + constexpr index_t MFMA_INST = Gemm0MFMA; + constexpr index_t LDS_READ_INST = OGradT_LDS_READ; + + constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / VMEM_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm12() + { + // Mem: Q^T LDS load + // Comp: PT x OGrad + constexpr index_t LDS_READ_INST = QT_LDS_READ; + constexpr index_t MFMA_INST = Gemm1MFMA + Gemm2MFMA; + + constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // VMEM read + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm3() + { + // Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load. + // Comp: SGradT x QT + constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE; + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ; + constexpr index_t MFMA_INST = Gemm3MFMA; + + constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST; + constexpr index_t lcm_inst = lcm(MFMA_INST, lds_rw_inst); + + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / lds_rw_inst) == 0) + { + if constexpr(i / (lcm_inst / lds_rw_inst) < LDS_WRITE_INST) + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + else + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS Read + } + }); + } + + CK_TILE_DEVICE static constexpr void SchedulerGemm4() + { + // Mem: SGrad, OGrad, D LDS load. + // Comp: SGrad x KT + constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ; + constexpr index_t MFMA_INST = Gemm4MFMA; + + constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST); + static_for<0, lcm_inst, 1>{}([&](auto i) { + if constexpr(i % (lcm_inst / MFMA_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i % (lcm_inst / LDS_READ_INST) == 0) + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + } + }; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 28d8b3eead..4652e5f20f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,9 @@ namespace ck_tile { // A is block distributed tensor // B is block distributed tensor // C is block distributed tensor -template +template struct BlockGemmARegBRegCRegV1 { private: @@ -44,8 +46,9 @@ struct BlockGemmARegBRegCRegV1 }; public: - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + static constexpr bool TransposeC = TransposeC_; using Traits = GemmTraits_; @@ -131,6 +134,7 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { + using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; if constexpr(UseDefaultScheduler) { constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -138,7 +142,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple<>, tuple<>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); @@ -152,7 +156,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); @@ -172,25 +176,19 @@ struct BlockGemmARegBRegCRegV1 std::is_same_v>, "wrong!"); - constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); - - constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode(); - - constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); - // check ABC-block-distribution static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "A distribution is wrong!"); static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "B distribution is wrong!"); static_assert( - std::is_same_v, + std::is_same_v, remove_cvref_t>, "C distribution is wrong!"); @@ -219,7 +217,6 @@ struct BlockGemmARegBRegCRegV1 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A Block window AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, a_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); @@ -227,16 +224,16 @@ struct BlockGemmARegBRegCRegV1 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B block tensor BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( merge_sequences(sequence{}, b_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor + using c_iter_idx = std:: + conditional_t, sequence>; CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM @@ -244,7 +241,7 @@ struct BlockGemmARegBRegCRegV1 // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); }); @@ -254,6 +251,7 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { + using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; if constexpr(UseDefaultScheduler) { constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< @@ -261,7 +259,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple<>, tuple<>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( @@ -277,7 +275,7 @@ struct BlockGemmARegBRegCRegV1 tuple, sequence>, tuple>, tuple>, - sequence<1, 2>, + c_distr_ys_major, sequence<0, 0>>{}; constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 849fa6c252..b72657b785 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -183,16 +183,7 @@ struct BlockReduce2dSync // pull data from remote lane const auto v_remote = warp_shuffle(v_local, src_lane); - - // For reduce, use combine_partial_results for operations that require it - if constexpr(ReduceFunc::requires_special_combine) - { - v_local = reduce_func.combine_partial_results(v_local, v_remote); - } - else - { - v_local = reduce_func(v_local, v_remote); - } + v_local = reduce_func(v_local, v_remote); }); } }); @@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; - - // For reduce, use combine_partial_results for operations that require it - if constexpr(ReduceFunc::requires_special_combine) - { - v_local = reduce_func.combine_partial_results(v_local, v_remote); - } - else - { - v_local = reduce_func(v_local, v_remote); - } + v_local = reduce_func(v_local, v_remote); }); y_tensor.get_thread_buffer()(i_0) = v_local; diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index f65487ea6e..0cae4023b7 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -189,7 +189,9 @@ struct Reduce /// @note Requirements: /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution) /// - input_strides[-1] == 1 (for contiguous memory access) - CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides) + template + CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, + InputStrides input_strides) { using S = typename Problem::BlockShape; diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index 4ce0b56ef3..821d0a6c3e 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -308,20 +308,8 @@ using TestConfig_F32_Max = std::tuple; -using TestConfig_F32_SquareAdd = std::tuple; - -using TestTypes = ::testing::Types; +using TestTypes = ::testing:: + Types; TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);