diff --git a/Jenkinsfile b/Jenkinsfile index 1e16b2f6f0..22468401dc 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){ if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } - if (params.RUN_CK_TILE_TESTS){ + if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_fwd_*.log" archiveArtifacts "perf_fmha_bwd_*.log" @@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 48, unit: 'HOURS') { @@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -668,7 +668,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } @@ -682,7 +682,7 @@ def process_results(Map conf=[:]){ timeout(time: 1, unit: 'HOURS'){ try{ dir("script"){ - if (params.RUN_CK_TILE_TESTS){ + if (params.RUN_CK_TILE_FMHA_TESTS){ try{ unstash "perf_fmha_fwd_gfx942.log" unstash "perf_fmha_bwd_gfx942.log" @@ -838,7 +838,7 @@ pipeline { dbsshport = "${dbsshport}" dbsshuser = "${dbsshuser}" dbsshpassword = "${dbsshpassword}" - status_wrapper_creds = "${status_wrapper_creds}" + ck_git_creds = "${ck_git_creds}" gerrit_cred="${gerrit_cred}" DOCKER_BUILDKIT = "1" } diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index ce17874cab..fa1897e23b 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.8.1 +rocm-docs-core==1.8.2 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a74b498f64..7d0c92d04f 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.8.1 +rocm-docs-core==1.8.2 # via -r requirements.in six==1.16.0 # via pybtex diff --git a/example/66_complex_contraction_bilinear/CMakeLists.txt b/example/66_complex_contraction_bilinear/CMakeLists.txt new file mode 100755 index 0000000000..c417caf8e7 --- /dev/null +++ b/example/66_complex_contraction_bilinear/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp) +add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp) + diff --git a/example/66_complex_contraction_bilinear/README.md b/example/66_complex_contraction_bilinear/README.md new file mode 100755 index 0000000000..04d92da0d2 --- /dev/null +++ b/example/66_complex_contraction_bilinear/README.md @@ -0,0 +1,11 @@ +# Instructions for ```example_complex_contraction_bilinear_xdl_fp32``` + +## Run +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_contraction_bilinear_xdl_fp32 1 1 1 +``` + + diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp new file mode 100644 index 0000000000..480ca5a0af --- /dev/null +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using F64 = double; + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Generic instances for fp32, fp16 and bf16 data types. +template +// clang-format off +using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +// Fp64 instances. +template +// clang-format off +using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp new file mode 100755 index 0000000000..619279c47b --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp new file mode 100755 index 0000000000..f3528d0901 --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; +using ComputeDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc new file mode 100755 index 0000000000..b548427548 --- /dev/null +++ b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc @@ -0,0 +1,484 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_complex_contraction_bilinear_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + // For Real Part of Complex Tensor + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + + // For Imaginary Part of Complex Tensor + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + + // Intermediate E tensor Definition + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; + std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; + std::cout << "d_ms_ns_re: " << d_ms_ns_re.mDesc << std::endl; + std::cout << "e_ms_ns_re: " << e_ms_ns_host_result_re.mDesc << std::endl; + + std::cout << "a_ms_ks_img: " << a_ms_ks_img.mDesc << std::endl; + std::cout << "b_ns_ks_img: " << b_ns_ks_img.mDesc << std::endl; + std::cout << "d_ms_ns_img: " << d_ms_ns_img.mDesc << std::endl; + std::cout << "e_ms_ns_img: " << e_ms_ns_host_result_img.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + + default: + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + break; + } + + DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + + DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + // Intermediate Value For E Real and Img + DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + + a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); + b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); + d_device_buf_re.ToDevice(d_ms_ns_re.mData.data()); + + a_device_buf_img.ToDevice(a_ms_ks_img.mData.data()); + b_device_buf_img.ToDevice(b_ns_ks_img.mData.data()); + d_device_buf_img.ToDevice(d_ms_ns_img.mData.data()); + + // set zero + e_device_buf_re.SetZero(); + e_device_buf_img.SetZero(); + + // set zero for intermediate values + e_device_buf_re1.SetZero(); + e_device_buf_img1.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_1 + + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{d_device_buf_re.GetDeviceBuffer()}, + e_device_buf_re1.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel}); + + + alpha = -1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_2 + // auto op = DeviceOpInstance{}; + // auto invoker = op.MakeInvoker(); + auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{e_device_buf_re1.GetDeviceBuffer()}, + e_device_buf_re.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); + + + alpha = 1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{d_device_buf_img.GetDeviceBuffer()}, + e_device_buf_img1.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + + if(!op.IsSupportedArgument(argument_img1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel}); + + alpha = 1.f; + beta = 1.f; + + auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{e_device_buf_img1.GetDeviceBuffer()}, + e_device_buf_img.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + + + if(!op.IsSupportedArgument(argument_img2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel}); + + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K * 2; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2; + + float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data()); + e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data()); + + auto isRealOk = 0; + auto isImgOk = 0; + + if(do_verification) + { + // Real Part Verification + Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument_re = + ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re(m0, m1, n0, n1), + d_ms_ns_re(m0, m1, n0, n1)); + } + } + } + } + + alpha = 1.f; + beta = -1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + auto ref_argument_re1 = + ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re1); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re1(m0, m1, n0, n1)); + } + } + } + } + + isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; + + + + + // Img Part Verification + Tensor c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + auto ref_argument_img = + ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img(m0, m1, n0, n1), + d_ms_ns_img(m0, m1, n0, n1)); + } + } + } + } + + auto ref_argument_img1 = + ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img1); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img1(m0, m1, n0, n1)); + } + } + } + } + + isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; + + return (isRealOk && isImgOk); + } + + return 0; +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index efae4e284a..c2f554f6cc 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[]) // different threshold for different dtype template -auto get_elimit(int /*init_method*/) +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) +{ + double rtol = 1e-2; + double atol = 1e-2; + if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN + { + rtol = 3.2e-2; + atol = 3.2e-2; + } + return ck_tile::make_tuple(rtol, atol); +} + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } // clang-format on - auto [rtol, atol] = get_elimit(init_method); + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); bool dq_cur_pass = ck_tile::check_err(dq_host_result, dq_host_ref, std::string("Error: QGrad Incorrect results!"), diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 723546a452..b9cb9a1ec2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - auto get_lengths = [&](bool permute, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - }; + struct + { + auto operator()(bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) + { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + } + + auto operator()(bool permute, + ck_tile::index_t ns /*num_splits*/, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) + { + if(permute) + return std::array{ns, b, h, s, d}; + else + return std::array{ns, b, s, h, d}; + } + } get_lengths; bool is_v_rowmajor = vlayout == std::string("r"); @@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( 1 < num_splits || use_kvcache - ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} + ? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v) : std::array{1, 1, 1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] @@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = hdim_v; + const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v); args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 183475064a..5dcad7907f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_bias, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, @@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqstart_q_ptr, args.hdim_v, args.num_splits, @@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o, - args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc); } @@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqlen_q, args.hdim_v, args.num_splits, diff --git a/example/ck_tile/04_img2col/CMakeLists.txt b/example/ck_tile/04_img2col/CMakeLists.txt new file mode 100644 index 0000000000..3864c9ed9d --- /dev/null +++ b/example/ck_tile/04_img2col/CMakeLists.txt @@ -0,0 +1,3 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) diff --git a/example/ck_tile/04_img2col/README.md b/example/ck_tile/04_img2col/README.md new file mode 100644 index 0000000000..6ae2cea5e5 --- /dev/null +++ b/example/ck_tile/04_img2col/README.md @@ -0,0 +1,12 @@ +# Image to Column + +This folder contains example for Image to Column using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_img2col -j +``` +This will result in an executable `build/bin/tile_example_img2col` diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp new file mode 100644 index 0000000000..6380cd2994 --- /dev/null +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck_tile/host.hpp" +#include "image_to_column.hpp" + +// Host API implementation +template <> +float image_to_column(const image_to_column_traits& traits, + const image_to_column_args<2>& args, + const ck_tile::stream_config& stream_conf) +{ + if(traits.data_type.compare("fp16") == 0) + { + constexpr ck_tile::index_t NDimSpatial = 2; + constexpr ck_tile::index_t VectorSize = 8; + + using thread_tile = ck_tile::sequence<8, 8>; + using warp_tile = ck_tile::sequence<64, 64>; + using block_tile = ck_tile::sequence<128, 128>; + + using Shape = ck_tile::TileImageToColumnShape; + + using InDataType = ck_tile::half_t; + using OutDataType = ck_tile::half_t; + + using PipelineProblem = ck_tile::BlockImageToColumnProblem; + + using Kernel = ck_tile::ImageToColumn; + + auto kargs = Kernel::MakeKargs(args.p_in, + args.p_out, + args.G, + args.N, + args.C, + args.input_spatial_lengths, + args.filter_spatial_lengths, + args.output_spatial_lengths, + args.image_g_n_c_wis_strides, + args.gemm_g_m_k_strides, + args.conv_filter_strides, + args.conv_filter_dilations, + args.input_left_pads, + args.input_right_pads); + + const dim3 grids = Kernel::GridSize( + args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1], + args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C, + args.G); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = 2; + + float ave_time = ck_tile::launch_kernel( + stream_conf, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +int main(int argc, char* argv[]) +{ + constexpr ck_tile::index_t NDimSpatial = 2; + + ExecutionConfig config; + ck_tile::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + using InDataType = ck_tile::half_t; + using OutDataType = ck_tile::half_t; + using ImLayout = ck_tile::tensor_layout::convolution::NHWGC; + + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck_tile::long_index_t NHoWo = + N * std::accumulate(conv_params.output_spatial_lengths_.begin(), + std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const ck_tile::long_index_t CYX = + C * std::accumulate(conv_params.filter_spatial_lengths_.begin(), + std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX}); + + // host verify + ck_tile::HostTensor in(in_desc); + ck_tile::HostTensor out_device(out_desc); + ck_tile::HostTensor out_host(out_desc); + + switch(config.init_method) + { + case 0: break; + case 1: ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(in); break; + default: ck_tile::FillUniformDistribution{-0.5, 0.5}(in); break; + } + + ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes()); + + in_device_buf.ToDevice(in.data()); + + image_to_column_traits traits{"fp16"}; + + image_to_column_args args{ + in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + ck_tile::to_array(conv_params.input_spatial_lengths_), + ck_tile::to_array(conv_params.filter_spatial_lengths_), + ck_tile::to_array(conv_params.output_spatial_lengths_), + ck_tile::to_array(in_desc.get_strides()), + ck_tile::to_array(out_desc.get_strides()), + ck_tile::to_array(conv_params.conv_filter_strides_), + ck_tile::to_array(conv_params.conv_filter_dilations_), + ck_tile::to_array(conv_params.input_left_pads_), + ck_tile::to_array(conv_params.input_right_pads_)}; + + float ave_time = + image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel}); + + std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(config.do_verification) + { + // reference + ck_tile::reference_im2col(in, out_host, conv_params); + + out_device_buf.FromDevice(out_device.data()); + pass = ck_tile::check_err(out_device, out_host); + + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + return !pass; +} diff --git a/example/ck_tile/04_img2col/image_to_column.hpp b/example/ck_tile/04_img2col/image_to_column.hpp new file mode 100644 index 0000000000..90484e08ec --- /dev/null +++ b/example/ck_tile/04_img2col/image_to_column.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/image_to_column.hpp" +#include + +#define DefaultConvParams \ + ck_tile::conv::ConvParam \ + { \ + 2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck_tile::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = + ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} + +struct image_to_column_traits +{ + std::string data_type; +}; + +template +struct image_to_column_args +{ + const void* p_in; + void* p_out; + const ck_tile::long_index_t G; + const ck_tile::long_index_t N; + const ck_tile::long_index_t C; + const ck_tile::array input_spatial_lengths; + const ck_tile::array filter_spatial_lengths; + const ck_tile::array output_spatial_lengths; + const ck_tile::array image_g_n_c_wis_strides; + const ck_tile::array gemm_g_m_k_strides; + const ck_tile::array conv_filter_strides; + const ck_tile::array conv_filter_dilations; + const ck_tile::array input_left_pads; + const ck_tile::array input_right_pads; +}; + +// host API +template +float image_to_column(const image_to_column_traits&, + const image_to_column_args&, + const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 3b4d1ca8be..fe1e9c9edf 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -5,3 +5,4 @@ include_directories(AFTER add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) +add_subdirectory(04_img2col) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 1121cc4550..438d7d8ac3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 } template <> - __device__ static constexpr auto TailScheduler<1>() + __device__ constexpr auto TailScheduler<1>() { // schedule constexpr auto num_ds_read_inst = @@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 } template <> - __device__ static constexpr auto TailScheduler<2>() + __device__ constexpr auto TailScheduler<2>() { // schedule constexpr auto num_ds_read_inst = diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp index a184431648..409bb9f674 100644 --- a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -324,55 +324,55 @@ struct DppSelector static constexpr auto GetDpp(); template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_8x32x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_8x16x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_16x16x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_32x8x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_1x32x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_2x32x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_2x16x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_4x16x2; } template <> - static constexpr auto GetDpp() + constexpr auto GetDpp() { return DppInstr::dpp8_f16_4x32x2; } diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 9a9ebf5595..b435a2a129 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -415,7 +415,7 @@ struct WmmaSelector static constexpr auto GetWmma(); template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { #ifdef __gfx12__ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; @@ -425,7 +425,7 @@ struct WmmaSelector } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { #ifdef __gfx12__ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; @@ -435,19 +435,19 @@ struct WmmaSelector } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_f16_16x16x16_f16; } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_bf16_16x16x16_bf16; } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { #ifdef __gfx12__ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; @@ -458,7 +458,7 @@ struct WmmaSelector #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu4; } diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 835075b7f2..24fac91e22 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -651,97 +651,97 @@ struct MfmaSelector static constexpr auto GetMfma(); template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f64_16x16x4f64; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x2xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x4xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -751,7 +751,7 @@ struct MfmaSelector } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -762,72 +762,72 @@ struct MfmaSelector #if defined(CK_USE_AMD_MFMA_GFX940) template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x16i8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x32i8; } #else template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x8i8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x16i8; } #endif template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8f8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8bf8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8bf8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8bf8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8bf8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8f8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8f8; } diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index c272b01f54..78768bbbfc 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include +#include #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" @@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array& a, const arr return !(a == b); } +template +CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector& x) +{ + array arr; + + static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; }); + + return arr; +} + template CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x) { diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index deebe90bf7..b382710b19 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -5,6 +5,8 @@ #include "ck_tile/host/arg_parser.hpp" #include "ck_tile/host/check_err.hpp" +#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp" +#include "ck_tile/host/convolution_parameter.hpp" #include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/fill.hpp" #include "ck_tile/host/hip_check_error.hpp" diff --git a/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp new file mode 100644 index 0000000000..b7317fc04b --- /dev/null +++ b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { +namespace conv { +namespace detail { + +template +CK_TILE_HOST std::vector get_layout_transpose_gnchw_to_old() +{ + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 2, 3}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 2, 3, 4}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 2, 3, 4, 5}; + } + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 3, 2}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 4, 2, 3}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {0, 1, 5, 2, 3, 4}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {2, 0, 3, 1}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {3, 0, 4, 1, 2}; + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return {4, 0, 5, 1, 2, 3}; + } + else + { + printf("%s\n", __func__); + throw std::runtime_error("wrong! unsupported layout"); + } +} + +} // namespace detail + +// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW +// regardless of physical layout +template +CK_TILE_HOST HostTensorDescriptor +make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param) +{ + std::vector physical_lengths; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", InLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX +// regardless of physical layout +template +CK_TILE_HOST HostTensorDescriptor +make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param) +{ + std::vector physical_lengths; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", WeiLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW +// regardless of physical layout +template +CK_TILE_HOST HostTensorDescriptor +make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param) +{ + std::vector physical_lengths; + + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.end(), + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", OutLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +} // namespace conv +} // namespace ck_tile diff --git a/include/ck_tile/host/convolution_parameter.hpp b/include/ck_tile/host/convolution_parameter.hpp new file mode 100644 index 0000000000..741a25ad73 --- /dev/null +++ b/include/ck_tile/host/convolution_parameter.hpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck_tile { +namespace conv { + +struct ConvParam +{ + ConvParam(); + ConvParam(ck_tile::index_t n_dim, + ck_tile::index_t group_count, + ck_tile::index_t n_batch, + ck_tile::index_t n_out_channels, + ck_tile::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw(std::runtime_error( + "ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck_tile::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + ConvParam(ck_tile::long_index_t n_dim, + ck_tile::long_index_t group_count, + ck_tile::long_index_t n_batch, + ck_tile::long_index_t n_out_channels, + ck_tile::long_index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(n_dim), + G_(group_count), + N_(n_batch), + K_(n_out_channels), + C_(n_in_channels), + filter_spatial_lengths_(filters_len), + input_spatial_lengths_(input_len), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(strides), + conv_filter_dilations_(dilations), + input_left_pads_(left_pads), + input_right_pads_(right_pads) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw(std::runtime_error( + "ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck_tile::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + ck_tile::long_index_t num_dim_spatial_; + ck_tile::long_index_t G_; + ck_tile::long_index_t N_; + ck_tile::long_index_t K_; + ck_tile::long_index_t C_; + + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + std::vector output_spatial_lengths_; + + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const + { + return output_spatial_lengths_; + } + + std::size_t GetFlops() const + { + // 2 * G * N * K * C * * + return static_cast(2) * G_ * N_ * K_ * C_ * + std::accumulate(std::begin(output_spatial_lengths_), + std::next(std::begin(output_spatial_lengths_), num_dim_spatial_), + 1, + std::multiplies<>()) * + std::accumulate(std::begin(filter_spatial_lengths_), + std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_), + 1, + std::multiplies<>()); + } + + template + std::size_t GetInputByte() const + { + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * + (G_ * N_ * C_ * + std::accumulate(std::begin(input_spatial_lengths_), + std::next(std::begin(input_spatial_lengths_), num_dim_spatial_), + 1, + std::multiplies<>())); + } + + template + std::size_t GetWeightByte() const + { + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * + (G_ * K_ * C_ * + std::accumulate(std::begin(filter_spatial_lengths_), + std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_), + 1, + std::multiplies<>())); + } + + template + std::size_t GetOutputByte() const + { + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G_ * N_ * K_ * + std::accumulate(std::begin(output_spatial_lengths_), + std::end(output_spatial_lengths_), + static_cast(1), + std::multiplies())); + } + + template + std::size_t GetByte() const + { + return GetInputByte() + GetWeightByte() + + GetOutputByte(); + } +}; + +ConvParam::ConvParam() + : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}) +{ +} + +CK_TILE_HOST std::string get_conv_param_parser_helper_msg() +{ + std::string msg; + + msg += "Following arguments (depending on number of spatial dims):\n" + " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n" + " G, N, K, C, \n" + " , (ie Y, X for 2D)\n" + " , (ie Hi, Wi for 2D)\n" + " , (ie Sy, Sx for 2D)\n" + " , (ie Dy, Dx for 2D)\n" + " , (ie LeftPy, LeftPx for 2D)\n" + " , (ie RightPy, RightPx for 2D)\n"; + + return msg; +} + +CK_TILE_HOST ck_tile::conv::ConvParam +parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]) +{ + const ck_tile::long_index_t G = std::stol(argv[arg_idx++]); + const ck_tile::long_index_t N = std::stol(argv[arg_idx++]); + const ck_tile::long_index_t K = std::stol(argv[arg_idx++]); + const ck_tile::long_index_t C = std::stol(argv[arg_idx++]); + + std::vector filter_spatial_lengths(num_dim_spatial); + std::vector input_spatial_lengths(num_dim_spatial); + std::vector conv_filter_strides(num_dim_spatial); + std::vector conv_filter_dilations(num_dim_spatial); + std::vector input_left_pads(num_dim_spatial); + std::vector input_right_pads(num_dim_spatial); + + for(int i = 0; i < num_dim_spatial; ++i) + { + filter_spatial_lengths[i] = std::stol(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_spatial_lengths[i] = std::stol(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_strides[i] = std::stol(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + conv_filter_dilations[i] = std::stol(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_left_pads[i] = std::stol(argv[arg_idx++]); + } + + for(int i = 0; i < num_dim_spatial; ++i) + { + input_right_pads[i] = std::stol(argv[arg_idx++]); + } + + return ck_tile::conv::ConvParam{num_dim_spatial, + G, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; +} + +} // namespace conv +} // namespace ck_tile diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 918abc69cc..f533d5c189 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -176,7 +176,20 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } - friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) + { + os << "dim " << desc.get_num_of_dimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.get_lengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.get_strides(), ", "); + os << "}"; + + return os; + } private: std::vector mLens; diff --git a/include/ck_tile/host/reference/reference_im2col.hpp b/include/ck_tile/host/reference/reference_im2col.hpp index 410140daa6..392d6abd47 100644 --- a/include/ck_tile/host/reference/reference_im2col.hpp +++ b/include/ck_tile/host/reference/reference_im2col.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,53 +9,125 @@ namespace ck_tile { -template -CK_TILE_HOST void reference_im2col(HostTensor& in_mtx_host_ref, - const HostTensor& in_host, - int /*N*/, - int /*K*/, - int C, - int /*Y*/, - int X, - int Hi, - int Wi, - int Ho, - int Wo, - int ConvStrideH, - int ConvStrideW, - int ConvDilationH, - int ConvDilationW, - int InLeftPadH, - int InLeftPadW, - int /*InRightPadH*/, - int /*InRightPadW*/) +template +CK_TILE_HOST void reference_im2col(const HostTensor& in_host, + HostTensor& out_host, + const ck_tile::conv::ConvParam& conv_params) { - int GemmM = in_mtx_host_ref.get_lengths()[0]; - int GemmK = in_mtx_host_ref.get_lengths()[1]; + const long_index_t G = in_host.get_lengths()[0]; + const long_index_t N = in_host.get_lengths()[1]; + const long_index_t C = in_host.get_lengths()[2]; - for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) + if constexpr(NDimSpatial == 1) { - int mtmp = gemm_m; - int n = mtmp / (Ho * Wo); - mtmp -= n * Ho * Wo; - int ho = mtmp / Wo; - int wo = mtmp - ho * Wo; + const long_index_t Wo = conv_params.output_spatial_lengths_[0]; + auto func = [&](auto g, auto n, auto wo) { + long_index_t row = n * Wo + wo; + long_index_t column = 0; - for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) - { - int ktmp = gemm_k; - int y = ktmp / (X * C); - ktmp -= y * X * C; - int x = ktmp / C; - int c = ktmp - x * C; + for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x) + { + auto wi = static_cast(wo * conv_params.conv_filter_strides_[0]) + + static_cast(x * conv_params.conv_filter_dilations_[0]) - + static_cast(conv_params.input_left_pads_[0]); - int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; - int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; + for(long_index_t c = 0; c < C; ++c) + { + if(wi >= 0 && type_convert(wi) < in_host.get_lengths()[3]) + { + InDataType v_in = in_host(g, n, c, wi); + out_host(g, row, column) = type_convert(v_in); + } + column++; + } + } + }; - bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); + make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + const long_index_t Ho = conv_params.output_spatial_lengths_[0]; + const long_index_t Wo = conv_params.output_spatial_lengths_[1]; - in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0; - } + auto func = [&](auto g, auto n, auto ho, auto wo) { + long_index_t row = n * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; + + for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y) + { + auto hi = static_cast(ho * conv_params.conv_filter_strides_[0]) + + static_cast(y * conv_params.conv_filter_dilations_[0]) - + static_cast(conv_params.input_left_pads_[0]); + + for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x) + { + auto wi = static_cast(wo * conv_params.conv_filter_strides_[1]) + + static_cast(x * conv_params.conv_filter_dilations_[1]) - + static_cast(conv_params.input_left_pads_[1]); + + for(long_index_t c = 0; c < C; ++c) + { + + if(hi >= 0 && type_convert(hi) < in_host.get_lengths()[3] && + wi >= 0 && type_convert(wi) < in_host.get_lengths()[4]) + { + InDataType v_in = in_host(g, n, c, hi, wi); + out_host(g, row, column) = type_convert(v_in); + } + column++; + } + } + } + }; + + make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + const long_index_t Do = conv_params.output_spatial_lengths_[0]; + const long_index_t Ho = conv_params.output_spatial_lengths_[1]; + const long_index_t Wo = conv_params.output_spatial_lengths_[2]; + + auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) { + long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; + long_index_t column = 0; + + for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z) + { + auto di = static_cast(d_o * conv_params.conv_filter_strides_[0]) + + static_cast(z * conv_params.conv_filter_dilations_[0]) - + static_cast(conv_params.input_left_pads_[0]); + for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y) + { + auto hi = static_cast(ho * conv_params.conv_filter_strides_[1]) + + static_cast(y * conv_params.conv_filter_dilations_[1]) - + static_cast(conv_params.input_left_pads_[1]); + for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x) + { + auto wi = + static_cast(wo * conv_params.conv_filter_strides_[2]) + + static_cast(x * conv_params.conv_filter_dilations_[2]) - + static_cast(conv_params.input_left_pads_[2]); + for(long_index_t c = 0; c < C; ++c) + { + if(di >= 0 && + type_convert(di) < in_host.get_lengths()[3] && + hi >= 0 && + type_convert(hi) < in_host.get_lengths()[4] && + wi >= 0 && type_convert(wi) < in_host.get_lengths()[5]) + { + InDataType v_in = in_host(g, n, c, di, hi, wi); + out_host(g, row, column) = type_convert(v_in); + } + column++; + } + } + } + } + }; + + make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency()); } } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index c022edf723..1569c93565 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask { auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); - const index_t x_per_split = ck_tile::max(1, x_total / num_splits); + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); const index_t split_start = x_per_split * i_split; - const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); + const index_t split_end = split_start + x_per_split; return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), ck_tile::min(origin_end, split_end)); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index e2c7db3e1b..ca9da91a5d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel void* o_ptr; ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t seqlen_q; ck_tile::index_t hdim_v; ck_tile::index_t num_splits; @@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; }; @@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel std::conditional_t>, std::conditional_t> { - ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; }; struct GroupModeKargs @@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel void* lse_ptr, void* o_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, @@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_ptr, o_ptr, batch, - max_seqlen_q, seqlen_q, hdim_v, num_splits, @@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for lse {}, // placeholder for fp8_static_quant args - batch_stride_o, - batch_stride_lse_acc}; + batch_stride_lse_acc, + batch_stride_o_acc, + batch_stride_o}; if constexpr(kStoreLSE) { @@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel void* lse_ptr, void* o_ptr, ck_tile::index_t batch, - ck_tile::index_t max_seqlen_q, const void* seqstart_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t num_splits, @@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc) { @@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel o_acc_ptr, o_ptr, batch, - max_seqlen_q, -1, // seqlen will be updated by another pointer hdim_v, num_splits, @@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel nhead_stride_lse_acc, nhead_stride_o_acc, nhead_stride_o, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for lse @@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel return kargs; } - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - const long_index_t batch_offset_o_acc = - static_cast(i_batch) * kargs.batch_stride_o_acc; - long_index_t batch_offset_lse_acc = 0; + long_index_t batch_offset_o_acc = 0; long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; @@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_o = query_start * kargs.row_stride_o; batch_offset_lse_acc = query_start; + batch_offset_o_acc = query_start * kargs.row_stride_o_acc; if constexpr(kStoreLSE) { batch_offset_lse = query_start; } + batch_offset_o = query_start * kargs.row_stride_o; + // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel } else { - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; if constexpr(kStoreLSE) { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } // for simplicity, batch stride we just modify the pointer @@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel auto o_acc_dram = [&]() { const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, - make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), number{}, number<1>{}); @@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel make_tuple(number<1>{}, number{}, number{}), sequence{}); - const index_t padded_max_seqlen_q = + const index_t padded_seqlen_q = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; const index_t padded_hdim_v = o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; return transform_tensor_view( o_acc_dram_view, - make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)), make_pass_through_transform(padded_hdim_v)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel identity{}, // lse_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func kargs.num_splits, - kargs.max_seqlen_q, + kargs.seqlen_q, smem_ptr); } else @@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel o_acc_dram_window, lse_dram_window, kargs.num_splits, - kargs.max_seqlen_q, + kargs.seqlen_q, smem_ptr); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp index 9f04843a39..3b73909712 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN1 = kN1_; - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t hdim_v) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * - ck_tile::integer_divide_ceil(hdim_v_, kN1), - nhead_, - batch_size_); + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * + ck_tile::integer_divide_ceil(hdim_v, kN1), + nhead, + batch_size); } CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) { - // const index_t num_tile_m0 = seqlen_q / kM0; const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t i_block = blockIdx.x; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 22978f1a3c..34f75990c6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_o_acc; - ck_tile::index_t batch_stride_lse_acc; - ck_tile::index_t batch_stride_o_acc; - ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; }; @@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; }; struct GroupModeKargs @@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_k; // only used for paged-kvcache + ck_tile::index_t batch_stride_v; // only used for paged-kvcache }; using Kargs = std::conditional_t; @@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias @@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, - batch_stride_v}; + batch_stride_v, + batch_stride_lse_acc, + batch_stride_o_acc}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_lse_acc, - ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_k, // only used for paged-kvcache + ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, @@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel nhead_stride_v, nhead_stride_lse_acc, nhead_stride_o_acc, - batch_stride_lse_acc, - batch_stride_o_acc, split_stride_lse_acc, split_stride_o_acc}, // args for common karg {}, // placeholder for bias @@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) { - return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); + return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; long_index_t batch_offset_lse_acc = 0; - const long_index_t batch_offset_o_acc = - static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_o_acc = 0; if constexpr(kIsGroupMode) { @@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_lse_acc = query_start; + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) { batch_offset_v = key_start * kargs.stride_v; @@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel batch_offset_bias = query_start * kargs.stride_bias + key_start; } + batch_offset_lse_acc = query_start; + batch_offset_o_acc = query_start * kargs.stride_o_acc; + // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel batch_offset_k = static_cast(i_cache_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_cache_batch) * kargs.batch_stride_v; batch_offset_lse_acc = static_cast(i_batch) * kargs.batch_stride_lse_acc; + batch_offset_o_acc = static_cast(i_batch) * kargs.batch_stride_o_acc; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel const auto o_acc_dram_naive = make_naive_tensor_view( o_acc_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), - number{}, + make_tuple(kargs.stride_o_acc, 1), + number<1>{}, number<1>{}); return pad_tensor_view( diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp index aec37cb36f..2d06ba1762 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, - ck_tile::index_t seqlen_q, + ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) * ck_tile::integer_divide_ceil(hdim_v, kN1), nhead * num_splits, batch_size); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 9e6a2725c9..3156e4a356 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }, s_acc, bias_s_tile); + __builtin_amdgcn_sched_barrier(0); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 4, OGrad@V Gemm2 auto dp_acc = SPGradBlockTileType{}; @@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 5, P^T(PGrad^T - D) auto ds = SPGradBlockTileType{}; @@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP Policy::template MakeBiasTileDistribution()); shuffle_tile(dbias_tile, shuffled_dbias_tile); store_tile(dbias_dram_window, dbias_tile); + __builtin_amdgcn_sched_barrier(0); } // STAGE 6, SGrad^T@Q^T Gemm3 @@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP move_tile_window(ds_lds_read_window, {0, kK4}); HotLoopScheduler::template GemmStagedScheduler<3>(); + __builtin_amdgcn_sched_barrier(0); // STAGE 7, SGrad@K^T Gemm4 auto dq_acc = QGradBlockTileType{}; clear_tile(dq_acc); @@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP }); HotLoopScheduler::template GemmStagedScheduler<4>(); + __builtin_amdgcn_sched_barrier(0); // Results Scale if constexpr(FmhaDropout::IsDropout) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 9e1ab81125..8647a7d25a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template <> - CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>() + CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>() { // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load // Comp: Q x K @@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template <> - CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>() + CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>() { // Mem: Q^T LDS load // Comp: OGrad x V @@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template <> - CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>() + CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>() { // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store // Comp: PT x OGrad @@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template <> - CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>() + CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>() { // Mem: SGradT LDS store, SGrad, Q, LSE LDS load. // Comp: SGradT x QT @@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } template <> - CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>() + CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>() { // Mem: SGrad, OGrad, D LDS load. // Comp: SGrad x KT diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7efdb798cb..842090afbe 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const LSEElementFunction& lse_element_func, const OaccElementFunction& o_acc_element_func, index_t num_splits, - index_t max_seqlen_q, + index_t seqlen_q, void* smem_ptr) const { // lse_acc tile in LDS @@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto o_acc = make_static_distributed_tensor(o_acc_dist); clear_tile(o_acc); - const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0; for(index_t i_split = 0; i_split < num_splits; ++i_split) { @@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } - move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); + move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0}); } o_acc = tile_elementwise_in(o_acc_element_func, o_acc); @@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const OaccDramBlockWindow& o_acc_dram_block_window, LSEDramBlockWindow& lse_dram_block_window, index_t num_splits, - index_t max_seqlen_q, + index_t seqlen_q, void* smem_ptr) const { return operator()(lse_acc_dram_block_window, @@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline identity{}, identity{}, num_splits, - max_seqlen_q, + seqlen_q, smem_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index b257b9e93d..75af7be82f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); }(); - static constexpr index_t kAlignmentO = - kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); @@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { const index_t original_num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); @@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); const auto tmp = [&]() { - if constexpr(FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; } diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp new file mode 100644 index 0000000000..57e83a7a51 --- /dev/null +++ b/include/ck_tile/ops/image_to_column.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" +#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" +#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp new file mode 100644 index 0000000000..ee74f1588f --- /dev/null +++ b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +template +struct ImageToColumn +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + + using Problem = remove_cvref_t; + + using InDataType = remove_cvref_t; + using OutDataType = remove_cvref_t; + + static constexpr index_t NDimSpatial = Problem::NDimSpatial; + + static constexpr index_t AligmentIn = Problem::AligmentIn; + static constexpr index_t AligmentOut = Problem::AligmentOut; + + static_assert(NDimSpatial == 2, "Not supported."); + + static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock; + static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock; + + struct Kargs + { + const void* p_in; + void* p_out; + + const long_index_t G; + const long_index_t N; + const long_index_t C; + + const array input_spatial_lengths; + const array filter_spatial_lengths; + const array output_spatial_lengths; + const array image_g_n_c_wis_strides; + const array gemm_g_m_k_strides; + const array conv_filter_strides; + const array conv_filter_dilations; + const array input_left_pads; + const array input_right_pads; + }; + + CK_TILE_HOST static constexpr Kargs + MakeKargs(const void* p_in, + void* p_out, + const long_index_t G, + const long_index_t N, + const long_index_t C, + const array input_spatial_lengths, + const array filter_spatial_lengths, + const array output_spatial_lengths, + const array image_g_n_c_wis_strides, + const array gemm_g_m_k_strides, + const array conv_filter_strides, + const array conv_filter_dilations, + const array input_left_pads, + const array input_right_pads) + { + return Kargs{p_in, + p_out, + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch) + { + return dim3( + integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } + + CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const + { + static_assert(NDimSpatial == 2, "Not supported."); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple( + kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C), + make_tuple(kargs.image_g_n_c_wis_strides[I1], + kargs.image_g_n_c_wis_strides[I3], + kargs.image_g_n_c_wis_strides[I4], + kargs.image_g_n_c_wis_strides[I2]), + number{}, + I1); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(kargs.N), + make_pad_transform(kargs.input_spatial_lengths[I0], + kargs.input_left_pads[I0], + kargs.input_right_pads[I0]), + make_pad_transform(kargs.input_spatial_lengths[I1], + kargs.input_left_pads[I1], + kargs.input_right_pads[I1]), + make_pass_through_transform(kargs.C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(kargs.N), + make_embed_transform( + make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]), + make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])), + make_embed_transform( + make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]), + make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])), + make_pass_through_transform(kargs.C)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple( + kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])), + make_merge_transform(make_tuple( + kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))), + make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const + { + static_assert(NDimSpatial == 2, "Not supported."); + const index_t M = kargs.N * static_cast(kargs.output_spatial_lengths[I0] * + kargs.output_spatial_lengths[I1]); + const index_t K = kargs.C * static_cast(kargs.filter_spatial_lengths[I0] * + kargs.filter_spatial_lengths[I1]); + return make_tuple(M, K); + } + + CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution() + { + using P = typename Problem::BlockShape; + // P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp} + // Y: {kMPerThread, kKPerThread} + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + + CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const + { + const auto [M, K] = CalculateMKDims(kargs); + + const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); + const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock); + const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0]; + const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0]; + + const auto image_m_k = make_tensor_view( + static_cast(kargs.p_in) + in_offset, MakeImageMKDesc(kargs)); + const auto gemm_m_k = make_naive_tensor_view( + static_cast(kargs.p_out) + out_offset, + make_tuple(M, K), + make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]), + number{}, + I1); + + const auto image_m_k_padded = + pad_tensor_view(image_m_k, + make_tuple(number{}, number{}), + sequence{}); + const auto gemm_m_k_padded = + pad_tensor_view(gemm_m_k, + make_tuple(number{}, number{}), + sequence{}); + + constexpr auto dstr = MakeBlockTileDistribution(); + + const auto image_tile = + make_tile_window(image_m_k_padded, + make_tuple(number{}, number{}), + {iM, iK}, + dstr); + + auto gemm_tile = make_tile_window(gemm_m_k_padded, + make_tuple(number{}, number{}), + {iM, iK}, + dstr); + + // load from Global + const auto loaded_tile = load_tile(image_tile); + // save to Global + store_tile(gemm_tile, loaded_tile); + } + + CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp b/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp new file mode 100644 index 0000000000..8d50ffde6d --- /dev/null +++ b/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct BlockImageToColumnProblem +{ + using InDataType = remove_cvref_t; + using OutDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr index_t NDimSpatial = NDimSpatial_; + static constexpr index_t AligmentIn = AligmentIn_; + static constexpr index_t AligmentOut = AligmentOut_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp new file mode 100644 index 0000000000..b038472fcf --- /dev/null +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template // Sequence<... +struct TileImageToColumnShape +{ + static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); + static constexpr index_t kKPerThread = ThreadTile::at(number<1>{}); + + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); + static constexpr index_t kKPerWarp = WarpTile::at(number<1>{}); + + static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; + static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread; + + static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t kKPerBlock = BlockTile::at(number<1>{}); + + static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; + static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp; + + static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock; +}; + +} // namespace ck_tile diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 26326523f4..5dae86089a 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 + REST_ARGS=${@:3} else GPU_TARGETS="gfx908;gfx90a;gfx940" + REST_ARGS= fi cmake \ @@ -20,4 +22,5 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ +$REST_ARGS \ ${MY_PROJECT_SOURCE} diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 25ccb5c799..f65ec610dd 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 if [ $# -ge 2 ] ; then GPU_TARGETS=$2 + REST_ARGS=${@:3} else GPU_TARGETS="gfx908;gfx90a;gfx940" + REST_ARGS= fi cmake \ @@ -20,5 +22,6 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ +$REST_ARGS \ ${MY_PROJECT_SOURCE} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 71bde7e267..e61d937f08 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -173,6 +173,7 @@ function(add_gtest_executable TEST_NAME) endfunction() add_compile_options(-Wno-c++20-extensions) +add_subdirectory(ck_tile) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) add_subdirectory(conv_util) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000..9075ca2ed0 --- /dev/null +++ b/test/ck_tile/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(image_to_column) diff --git a/test/ck_tile/image_to_column/CMakeLists.txt b/test/ck_tile/image_to_column/CMakeLists.txt new file mode 100644 index 0000000000..247358dd4d --- /dev/null +++ b/test/ck_tile/image_to_column/CMakeLists.txt @@ -0,0 +1,4 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp) +endif() diff --git a/test/ck_tile/image_to_column/test_tile_image_to_column.cpp b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp new file mode 100644 index 0000000000..9c0746e972 --- /dev/null +++ b/test/ck_tile/image_to_column/test_tile_image_to_column.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/image_to_column.hpp" + +// Host API implementation +template +class TestCkTileImageToColumn : public ::testing::Test +{ + static constexpr ck_tile::index_t VectorSize = 1; + static constexpr ck_tile::index_t NDimSpatial = 2; + + protected: + void Run(const ck_tile::conv::ConvParam conv_params) + { + + using ImLayout = ck_tile::tensor_layout::convolution::NHWGC; + + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck_tile::long_index_t NDoHoWo = + N * std::accumulate(conv_params.output_spatial_lengths_.begin(), + std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const ck_tile::long_index_t CZYX = + C * std::accumulate(conv_params.filter_spatial_lengths_.begin(), + std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_params); + const auto out_desc = ck_tile::HostTensorDescriptor({G, NDoHoWo, CZYX}); + + // host verify + ck_tile::HostTensor in(in_desc); + ck_tile::HostTensor out_device(out_desc); + ck_tile::HostTensor out_host(out_desc); + + std::cout << "input: " << in.mDesc << std::endl; + std::cout << "output: " << out_device.mDesc << std::endl; + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(in); + + ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes()); + + in_device_buf.ToDevice(in.data()); + + using thread_tile = ck_tile::sequence<4, 4>; + using warp_tile = ck_tile::sequence<8, 128>; + using block_tile = ck_tile::sequence<32, 128>; + + using Shape = ck_tile::TileImageToColumnShape; + + using PipelineProblem = ck_tile::BlockImageToColumnProblem; + + using Kernel = ck_tile::ImageToColumn; + + auto kargs = Kernel::MakeKargs( + in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + ck_tile::to_array( + conv_params.input_spatial_lengths_), + ck_tile::to_array( + conv_params.filter_spatial_lengths_), + ck_tile::to_array( + conv_params.output_spatial_lengths_), + ck_tile::to_array(in_desc.get_strides()), + ck_tile::to_array(out_desc.get_strides()), + ck_tile::to_array(conv_params.conv_filter_strides_), + ck_tile::to_array( + conv_params.conv_filter_dilations_), + ck_tile::to_array(conv_params.input_left_pads_), + ck_tile::to_array(conv_params.input_right_pads_)); + + const dim3 grids = Kernel::GridSize( + kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1], + kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C, + kargs.G); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = 2; + + ck_tile::launch_kernel( + ck_tile::stream_config{}, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + // reference + ck_tile::reference_im2col(in, out_host, conv_params); + + out_device_buf.FromDevice(out_device.data()); + bool pass = ck_tile::check_err(out_device, out_host); + + EXPECT_TRUE(pass); + } +}; + +class TestCkTileImageToColumnFloat : public TestCkTileImageToColumn +{ +}; + +class TestCkTileImageToColumnHalf : public TestCkTileImageToColumn +{ +}; + +TEST_F(TestCkTileImageToColumnFloat, TestCorrectness) +{ + this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}}); + this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); +} + +TEST_F(TestCkTileImageToColumnHalf, TestCorrectness) +{ + this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}}); + this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); +}