diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index af36f492ba..0d7bcd6b18 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk 
@andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 0baa503334..cc6178b08c 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -54,7 +54,7 @@ jobs: with: repository: "ROCm/TheRock" path: "TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Setup ccache run: | @@ -78,8 +78,9 @@ jobs: run: | git config --global --add safe.directory '*' # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo - rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch - git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0003-Find-rocm_smi-via-config-files.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0007-Remove-Windows-third_party_dlls-copying-code.patch + # git -c user.name="therockbot" -c 
"user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 565d1d3e54..74f3bb0017 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,7 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index cd255a40b6..e4bd295c95 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: "Configuring CI options" env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062..066dc9aa3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,12 +5,16 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". -* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. 
+* Added streamingllm sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines. * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added support for gfx1153 target. +* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. +* Added gpt-oss sink support for FMHA FWD, including qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. ### Changed @@ -40,6 +44,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added top-k sigmoid kernel in CK_TILE * Added the blockscale 2D support for CK_TILE GEMM. * Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types +* Added reduce and multi reduction kernels ### Changed diff --git a/CMakeLists.txt b/CMakeLists.txt index eaed7d3509..121c663f64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,11 +31,12 @@ endif() # Default installation path if(NOT WIN32) set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") +else() + set(CMAKE_INSTALL_PREFIX "C:/dist/TheRock" CACHE PATH "") endif() set(version 1.2.0) -# Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) +project(composable_kernel VERSION ${version} LANGUAGES CXX) include(CTest) option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) @@ -162,7 +163,13 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h) set(ROCM_SYMLINK_LIBS OFF) -find_package(ROCM REQUIRED PATHS /opt/rocm) + +if (WIN32) + find_package(ROCmCMakeBuildTools REQUIRED PATHS C:/dist/TheRock) + set(HIP_PLATFORM "amd" CACHE STRING "HIP platform") +else() + find_package(ROCM REQUIRED PATHS /opt/rocm) +endif() include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) @@ -189,7 +196,10 @@ if(GPU_TARGETS) else() set(USER_GPU_TARGETS 0) endif() + find_package(hip REQUIRED) +enable_language(HIP) + # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 94591f9012..020afeccf4 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN pip install pandas zmq einops ninja && \ +RUN pip install pandas zmq einops ninja tabulate && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ diff --git a/Jenkinsfile b/Jenkinsfile index cb2f8631c5..58b5194f60 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){ def setup_cmd def build_cmd def execute_cmd = conf.get("execute_cmd", "") + //check the node gpu architecture + def arch_name = check_arch_name() if(!setup_args.contains("NO_CK_BUILD")){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" @@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){ //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json" - archiveArtifacts "ck_build_trace_${check_arch_name()}.json" - sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json" + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json" + archiveArtifacts "ck_build_trace_${arch_name}.json" + sh "python3 ../script/parse_ninja_trace.py 
ck_build_trace_${arch_name}.json" if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){ if (params.NINJA_FTIME_TRACE) { echo "running ClangBuildAnalyzer" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log" - archiveArtifacts "clang_build_analysis_${check_arch_name()}.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log" + archiveArtifacts "clang_build_analysis_${arch_name}.log" } @@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" } } if(params.BUILD_INSTANCES_ONLY){ @@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" } } } } } - //check the node gpu architecture - def arch_name = check_arch_name() if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_*.log" @@ -811,41 +811,12 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_*.log" stash includes: "perf_**.log", name: 
"perf_log_${arch}" } - // disable performance tests on gfx1030 for now. - //else if ( arch == "gfx10"){ - // run basic tests on gfx1030 - // echo "Run gemm performance tests" - // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" - // archiveArtifacts "perf_onnx_gemm_gfx10.log" - // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" - //} - else if ( arch == "gfx11"){ - // run basic tests on gfx11 + else if ( arch != "gfx10"){ + // run basic tests on gfx11/gfx12/gfx908/gfx950, but not on gfx10, it takes too long echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" - archiveArtifacts "perf_onnx_gemm_gfx11.log" - stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" - } - else if ( arch == "gfx120" ){ - // run basic tests on gfx12 - echo "Run gemm performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" - archiveArtifacts "perf_onnx_gemm_gfx12.log" - stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" - } - else if ( arch == "gfx908" ){ - // run basic tests on gfx908 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" - archiveArtifacts "perf_onnx_gemm_gfx908.log" - stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" - } - else if ( arch == "gfx950" ){ - // run basic tests on gfx950 - echo "Run performance tests" - sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" - archiveArtifacts "perf_onnx_gemm_gfx950.log" - stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} ${arch}" + archiveArtifacts "perf_onnx_gemm_*.log" + stash includes: 
"perf_onnx_gemm_**.log", name: "perf_log_${arch}" } } } @@ -1046,9 +1017,10 @@ def run_aiter_tests(Map conf=[:]){ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" - //sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_batch_prefill.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -1201,8 +1173,8 @@ pipeline { description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( name: "RUN_TILE_ENGINE_BASIC_TESTS", - defaultValue: false, - description: "Run the tile_engine_basic tests (default: OFF)") + defaultValue: true, + description: "Run the tile_engine_basic tests (default: ON)") booleanParam( name: "RUN_TILE_ENGINE_GEMM_TESTS", defaultValue: false, @@ -1346,21 +1318,15 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. && git ls-files \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ + execute_cmd = """cd .. && \ + find . 
-type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ - --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) @@ -1376,17 +1342,10 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "(cd .. && git ls-files \ - \'*.h\' \ - \'*.hpp\' \ - \'*.cpp\' \ - \'*.h.in\' \ - \'*.hpp.in\' \ - \'*.cpp.in\' \ - \'*.cl\' \ - | grep -v 'build/' \ - | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" + execute_cmd = """cd .. && \ + find . 
-type f \\( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.cl' \\) \ + -not -path '*/build/*' -not -path '*/include/rapidjson/*' | \ + xargs -P 8 -I{} sh -c 'clang-format-18 -style=file {} | diff -u - {} || (echo "ERROR: {} needs formatting" && exit 1)'""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd) @@ -1469,8 +1428,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) @@ -1650,7 +1609,10 @@ pipeline { -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \ - ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """ + ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) @@ -1667,37 +1629,6 @@ pipeline { } parallel { - stage("Run TILE_ENGINE_GEMM Tests on gfx90a") - { - when { - beforeAgent true - expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx90a") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx90a" \ - -D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \ - -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \ - -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ - -D GEMM_STREAMK_LAYOUT="rcr" \ - -D GEMM_MULTI_D_DATATYPE="fp16" \ - -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ - -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ - python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . 
--problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } stage("Run TILE_ENGINE_GEMM Tests on gfx942") { when { @@ -1787,7 +1718,10 @@ pipeline { } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + // SLES15 is a legacy platform with limited C++20 ecosystem support (older system libraries, + // standard library implementation). While the ROCm compiler supports C++20, the experimental + // CK Builder requires full C++20 feature support that may not be reliably available on SLES15. + setup_args = """ -DGPU_TARGETS="gfx942" -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 -DCK_EXPERIMENTAL_BUILDER=OFF """ execute_args = " " } steps{ diff --git a/README.md b/README.md index 8a5258bab6..09540ff245 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,22 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ``` **[See Note on -j](#notes)** +### Building for Windows + +Install TheRock and run CMake configure as + +```bash + cmake \ + -D CMAKE_PREFIX_PATH="C:/dist/TheRock" \ + -D CMAKE_CXX_COMPILER="C:/dist/TheRock/bin/hipcc.exe" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx1151" \ + -G Ninja \ + .. +``` + +Use Ninja to build either the whole library or individual targets. + ## Optional post-install steps * Build examples and tests: diff --git a/docs/conceptual/ck_tile/buffer_views.rst b/docs/conceptual/ck_tile/buffer_views.rst index 14b8309504..03b8e87b1b 100644 --- a/docs/conceptual/ck_tile/buffer_views.rst +++ b/docs/conceptual/ck_tile/buffer_views.rst @@ -1,35 +1,13 @@ -.. meta:: - :description: Composable Kernel CK Tile buffer views - :keywords: composable kernel, CK, CK Tile, ROCm, API, buffer view, raw memory - ..
_ck_tile_buffer_views: -CK Tile buffer view -======================= - -Buffer view is an abstraction that provides structured access to memory. The ``buffer_view`` class is exposed in ``include/ck_tile/core/tensor/buffer_view.hpp``. - -Buffer view serves as the foundation for :ref:`ck_tile_tensor_views`. BufferView handles memory addressing and type safety, while TensorView builds upon this to add multi-dimensional coordinates (shape and strides). - - -Buffer view provides the following advantages: - -* A unified interface across global, shared, and register memory -* Address spaces encoded in types, taking advantage of compile-time type checking -* Configurable handling of invalid values, out-of-bounds operations, and conditional access patterns -* Atomic operations for parallel algorithms -* AMD GPU-specific optimizations -* Automatic application of appropriate memory ordering constraints and cache control directives based on the target address space and operation type - - -[TO DO: do we want to say more about these items? There wasn't a lot of detail in the original text, so I put them in a list for now] - - +Buffer Views - Raw Memory Access Address Space Usage Patterns ---------------------------- -[TO DO: explain in words what the diagram shows] +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -66,18 +44,26 @@ Address Space Usage Patterns style Compute fill:#e0e7ff,stroke:#4338ca,stroke-width:2px + + + + .. 
image:: diagrams/buffer_views_1.svg :alt: Diagram :align: center +C++ Implementation +------------------ +**File**: ``include/ck_tile/core/tensor/buffer_view.hpp`` Basic Creation ~~~~~~~~~~~~~~ -[TO DO: remove "modern C++ template metaprogramming" and "zero-overhead abstraction"] +By encoding critical properties such as buffer size and address space as template parameters, BufferView transforms what would traditionally be runtime decisions into compile-time constants. This design philosophy enables the compiler to perform aggressive optimizations, including constant propagation, loop unrolling, and instruction selection, that would be impossible with runtime parameters. -[TO DO: might want to move the implementation details to a separate section under "reference"] +The use of compile-time constants extends beyond mere optimization. When the buffer size is encoded in the type system using constructs like ``number<8>{}``, the compiler can statically verify that array accesses are within bounds, eliminate unnecessary bounds checks, and even restructure algorithms to better match the known data dimensions. This compile-time knowledge propagates through the entire computation, enabling optimizations at every level of the abstraction hierarchy. +The address space template parameter represents another crucial design decision. By making the memory space part of the type system, BufferView ensures that operations appropriate for one memory space cannot be accidentally applied to another. This type safety prevents common errors such as attempting atomic operations on register memory or using global memory synchronization primitives on local memory. The compiler enforces these constraints at compile time, transforming potential runtime errors into compile-time diagnostics. .. 
code-block:: cpp @@ -98,7 +84,6 @@ Basic Creation buffer_size // number of elements ); - // Implementation detail: The actual C++ template is: // template (data, buffer_size, custom_invalid); - - // Invalid element access with is_valid_element=false - // Returns custom_invalid due to custom invalid value mode - auto invalid_value = buffer_view.template get(0, 0, false); - printf("Invalid element: %.1f\n", invalid_value.get(0)); - - // Out of bounds access - AMD buffer addressing handles bounds checking - // Will return custom_invalid when accessing beyond buffer_size - auto oob_value = buffer_view.template get(0, 100, true); - printf("Out of bounds: %.1f\n", oob_value.get(0)); - - - - - Get Operations -------------- -[TO DO: might want to put this implementation detail in the reference section] +Scalar Access +~~~~~~~~~~~~~ -The signature for the ``buffer_view`` ``get()`` takes four parameters: +The get operations in BufferView form the cornerstone of memory access patterns in CK Tile. These operations embody an advanced understanding of GPU memory systems and the patterns that lead to optimal performance. The scalar access interface incorporates multiple layers of optimization and safety mechanisms that work together to provide both performance and correctness. -``i``: the primary offset into the buffer expressed in terms of elements of type T rather than raw bytes. +The parameter structure of scalar access operations reflects careful design choices aimed at maximizing flexibility while maintaining efficiency. The base index parameter ``i`` represents the primary offset into the buffer, expressed in terms of elements of type T rather than raw bytes. This type-aware indexing prevents common errors related to pointer arithmetic and ensures that vector types are handled correctly.
The additional ``linear_offset`` parameter provides fine-grained control over the final access location, enabling complex access patterns without requiring expensive index calculations in the kernel code. -``linear_offset``: [TO DO: what is this?] +The ``is_valid_element`` parameter provides a solution to conditional memory access. Rather than using traditional if-statements that would cause warp divergence, this boolean parameter enables predicated execution where the memory access occurs unconditionally but the result is conditionally used. This approach maintains uniform control flow across all threads in a warp, preserving the SIMD execution model that is fundamental to GPU performance. -``is_valid_element``: [TO DO: what is this?] +The invalid value modes provide a mechanism for handling the boundary conditions that arise in parallel algorithms. When ``InvalidElementUseNumericalZeroValue`` is set to true, the system returns zero for any invalid access, whether due to the ``is_valid_element`` flag or out-of-bounds indexing. This mode is important for algorithms where zero serves as a natural extension value, such as in image processing with zero-padding or sparse matrix operations where missing elements are implicitly zero. -[TO DO: the last param, that's the out of bounds handling, yes? -.. code:: cpp +The custom invalid value mode, activated when ``InvalidElementUseNumericalZeroValue`` is false, offers additional flexibility for algorithms with specific boundary requirements. This mode returns a user-specified value for invalid accesses, accommodating use cases such as sentinel values in sorting algorithms, infinity values in optimization problems, or special markers in data processing pipelines. The implementation ensures that this flexibility comes without performance penalty, using the same branchless execution strategies as the zero mode. 
- get(index_t i, - index_t linear_offset, - bool is_valid_element, - bool_constant = {}) +Out-of-bounds handling leverages AMD GPU hardware capabilities to provide safety with minimal impact to performance. When AMD buffer addressing is enabled, the hardware automatically clamps memory accesses to valid ranges, preventing the segmentation faults that would occur on CPU systems. This hardware-assisted bounds checking operates at wire speed, adding no overhead to the memory access path while ensuring that kernels cannot corrupt memory outside their allocated regions. +Vector Access +~~~~~~~~~~~~~ -[TO DO: need some context around the code] +Vector memory operations represent one of the most critical optimizations available in modern GPU programming, and BufferView's vector access interface exposes this capability. By using template parameters to specify vector types through constructs like ``ext_vector_t``, the interface enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. This vectorization is crucial for :ref:`ck_tile_load_store_traits`, which automatically selects optimal access patterns. -[TO DO: code chunks need to have detail and explanation so that the reader can see what they're trying to demonstrate.] +The significance of vector operations extends beyond bandwidth improvements. GPUs are designed with wide memory buses that can transfer 128, 256, or even 512 bits per transaction. When scalar operations access only 32 bits at a time, they utilize only a fraction of this available bandwidth. Vector operations align with these wide buses, enabling full bandwidth utilization and reducing the total number of memory transactions required. +The implementation of vector access maintains the same parameter structure as scalar operations, providing consistency across the API while automatically handling the complexities of multi-element transfers. 
The system manages alignment requirements, ensures that vector loads and stores use the optimal hardware instructions, and handles cases where vector operations extend beyond buffer boundaries. This transparent handling of edge cases allows developers to use vector operations confidently without manual boundary checks or special-case code for partial vectors. -.. code-block:: cpp - - // Create buffer view - float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; - auto buffer_view = make_buffer_view(data, 8); - - // Simple get - compile-time bounds checking when possible - auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view - float value = value_buf.get(0); //get the value from the buffer - - // Get with valid flag - branchless conditional access - bool valid_flag = false; - value_buf = buffer_view.template get(0,1,valid_flag); - value = value_buf.get(0); - // Returns 0 valid_flag is false - - // vectorized get - using float2 = ext_vector_t; - auto vector_buf = buffer_view.template get(0, 0, true); - // Loads 2 floats in a single instruction - float val1 = vector_buf.get(0); - float val2 = vector_buf.get(1); - } - -``ext_vector_t`` enables compile-time selection of optimal load and store instructions that can transfer multiple data elements in a single memory transaction. - -[TO DO: what is it actually doing? When does one use scalars vs vectors? Is it application specific or are there ] +Scalar vs Vectorized Memory Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. Original mermaid diagram (edit here, then run update_diagrams.py) @@ -287,8 +216,9 @@ The signature for the ``buffer_view`` ``get()`` takes four parameters: Understanding BufferView Indexing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -[TO DO: an explanation of the diagram is needed] - +.. + Original mermaid diagram (edit here, then run update_diagrams.py) + .. 
Original mermaid diagram (edit here, then run update_diagrams.py) @@ -335,14 +265,69 @@ Understanding BufferView Indexing .. image:: diagrams/buffer_views_3.svg :alt: Diagram :align: center - - + +C++ Get Operations +~~~~~~~~~~~~~~~~~~ + +.. code-block:: cpp + + __device__ void example_get_operations() + { + // Create buffer view + float data[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + auto buffer_view = make_buffer_view(data, 8); + + // Simple get - compile-time bounds checking when possible + auto value_buf = buffer_view.template get(0,1,true); //get the buffer from the buffer view + float value = value_buf.get(0); //get the value from the buffer + + // Get with valid flag - branchless conditional access + bool valid_flag = false; + value_buf = buffer_view.template get(0,1,valid_flag); + value = value_buf.get(0); + // Returns 0 when valid_flag is false + + // vectorized get + using float2 = ext_vector_t; + auto vector_buf = buffer_view.template get(0, 0, true); + // Loads 2 floats in a single instruction + float val1 = vector_buf.get(0); + float val2 = vector_buf.get(1); + } + +Custom Value Return Mode for OOB & Invalid Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code-block:: cpp + + void scalar_get_operations_example() { + + // Create data array + constexpr size_t buffer_size = 8; + float data[buffer_size] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + float custom_invalid = 13.0f; + + // Create global memory buffer view with custom invalid value mode + auto buffer_view = make_buffer_view(data, buffer_size, custom_invalid); + + // Invalid element access with is_valid_element=false + // Returns custom_invalid due to custom invalid value mode + auto invalid_value = buffer_view.template get(0, 0, false); + printf("Invalid element: %.1f\n", invalid_value.get(0)); + + // Out of bounds access - AMD buffer addressing handles bounds checking + // Will return custom_invalid when accessing beyond buffer_size + auto oob_value = buffer_view.template get(0, 100, true); + printf("Out of bounds: %.1f\n", oob_value.get(0)); + } + +.. note:: + + Partially out-of-bounds (OOB) vector reads return 'junk' values for the OOB elements. The zero or custom invalid value is returned only when the access is completely invalid/OOB; in other words, only when the first address of the vector is invalid. Update Operations ----------------- -Update operations modify the buffer content. The ``set()`` method writes a value to a specific location. - .. code-block:: cpp void scalar_set_operations_example() { @@ -373,8 +358,6 @@ Update operations modify the buffer content. The ``set()`` method writes a value Atomic Operations ----------------- -[TO DO: this needs information] - Atomic vs Non-Atomic Operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -441,3 +424,21 @@ C++ Atomic Operations __syncthreads(); } + +Summary +------- + +BufferView abstracts GPU memory hierarchies behind a concise interface. The approach is intended to keep overhead small while enabling optimizations that are otherwise awkward in low-level code. + +BufferView offers a unified interface across global, shared, and register memory. 
Using the same API for each space can lower cognitive overhead, reduce certain classes of mistakes, and support code reuse via template parameters. + +Address spaces are encoded in types so that common errors are reported at compile time. Consistent with CK Tile’s zero-overhead design aim, compile-time checks are favored over runtime guards. The C++ type system enforces memory-space constraints and can make valid cases more amenable to compiler optimization. + +BufferView supports configurable handling of invalid values, optional runtime bounds checks, and conditional access patterns. It also provides atomic operations for thread-safe updates. These features are intended to cover common edge cases without adding unnecessary overhead. + +By hiding the complexity of different memory spaces while exposing the operations needed for high-performance GPU computing, BufferView establishes a pattern that the rest of CK Tile follows: compile-time abstractions that enhance rather than compromise performance. The :ref:`ck_tile_tensor_views` and :ref:`ck_tile_distribution` add capability while maintaining the efficiency established at the base. For hardware-specific details about memory hierarchies, see :ref:`ck_tile_gpu_basics`. + +Next Steps +---------- + +Continue to :ref:`ck_tile_tensor_views` to learn how to build structured tensor views on top of buffer views. 
diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b1ab09e6f7..f2fb27e2b9 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.31.1 +rocm-docs-core[api_reference]==1.31.3 sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 099e9e439f..23397503df 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.31.1 +rocm-docs-core[api-reference]==1.31.3 # via -r requirements.in rpds-py==0.24.0 # via diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2d65368d4f..aba462638e 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -149,3 +149,7 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) +add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle) +add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_bpreshuffle) diff --git a/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..d03971e6ec --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_bpreshuffle.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F16; +using ComputeTypeB = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 8; // int4 -> 32, fp8 -> 16, fp16 -> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 32, 128, 128, + 8, 8, + 16, 16, + 2, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, 
ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..8f8b380b93 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_bpreshuffle.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/scheduler_enum.hpp" + +#include +#include +#include + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using ComputeTypeA = F8; +using ComputeTypeB = F8; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; +static constexpr int KPack = 16; // int4 -> 32, fp8 -> 16, fp16 
-> 8 +// clang-format off +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 32, 128, 256, + 16, 16, + 16, 16, + 2, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB>; +// clang-format on + +#include "run_gemm_wmma_bpreshuffle_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc new file mode 100644 index 0000000000..b1d73cfe10 --- /dev/null +++ b/example/01_gemm/run_gemm_wmma_bpreshuffle_example.inc @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor 
c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // do GEMM + auto device_op = DeviceOpInstance{}; + + // weight pre-shuffle + int NPerWmma = device_op.GetPreShuffleParameters(); + int KLane = ck::get_warp_size() / NPerWmma; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NPerWmma + // N, K -> N0 K0 KLane NPerWmma KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NPerWmma; + int n1 = n % NPerWmma; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane + + k1 * KPack * NPerWmma + n1 * KPack + k2; + + b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k); + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + 
if(!device_op.IsSupportedArgument(argument)) + { + std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1}; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 0bded7d2ac..9b48d5765d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ 
b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -119,7 +119,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 4acf4fe9ff..a770bf5c77 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -119,7 +119,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 55f3d99823..f8299028da 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/12_reduce/reduce_multiblock_atomic_add.cpp b/example/12_reduce/reduce_multiblock_atomic_add.cpp index af5903f83c..66fc2bb582 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add.cpp +++ b/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/12_reduce/reduce_threadwise_multi_d.cpp b/example/12_reduce/reduce_threadwise_multi_d.cpp index e77daea212..ee06395771 100644 --- a/example/12_reduce/reduce_threadwise_multi_d.cpp +++ b/example/12_reduce/reduce_threadwise_multi_d.cpp @@ -31,7 +31,7 @@ class SimpleAppArgs bool do_verification = true; int data_type = 1; int init_method = 2; - bool time_kernel = 
true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index f0a9ce9270..fc083ba3e2 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -53,7 +53,7 @@ int main(int argc, char* argv[]) { do_verification = true; init_method = 1; - time_kernel = true; + time_kernel = false; } else if(argc == 4) { diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp index cc5e3616ff..7437d0be9d 100644 --- a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -27,10 +27,11 @@ using ::ck::Tensor; template using S = ck::Sequence; -using I8 = int8_t; -using I32 = int32_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; @@ -125,11 +126,11 @@ int main(int /* argc */, char* /* argv */[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index ce41c3310f..a7dae9dcd8 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl 
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) +add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp index 62d2022084..6fe285f165 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -90,7 +90,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; int k_batch = 128; - bool time_kernel = true; + bool time_kernel = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp new file mode 100644 index 0000000000..bd58ea433f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDs = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| 
AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>; +// clang-format on + +#include "run_grouped_gemm_multiple_d_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 1db8a9defb..9fdcf4aaad 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -71,339 +71,6 @@ using DeviceGemmInstance 
= < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; +#include "run_grouped_gemm_multiple_d_example.inc" - std::vector stride_As; - std::vector stride_Bs; - std::vector> stride_Ds; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; - using GemmDesc = ck::tensor_operation::device::GemmDesc; - - // GEMM shape - std::vector gemm_descs; - std::vector ggemm_kargs; - std::vector p_Cs; - std::vector p_As; - std::vector p_Bs; - std::vector> p_Ds = {}; - - gemm_descs.reserve(group_count); - ggemm_kargs.reserve(group_count); - p_As.reserve(group_count); - p_Bs.reserve(group_count); - p_Ds.reserve(group_count); - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - std::vector> a_tensors; - std::vector> b_tensors; - std::vector, NumDs>> d_tensors; - std::vector> c_host_tensors; - std::vector> c_device_result_tensors; - - a_tensors.reserve(group_count); - b_tensors.reserve(group_count); - d_tensors.reserve(group_count); - c_host_tensors.reserve(group_count); - 
c_device_result_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a_tensors_device, b_tensors_device, c_tensors_device; - std::vector> d_tensors_device; - - a_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - d_tensors_device.resize(group_count); // reserve and update vector size - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - a_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); - - auto d0_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - auto d1_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - - std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; - d_tensors.push_back(d_tens); - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc - << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + - sizeof(BDataType) * b_tensors[i].GetElementSize() + - sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + - sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - 
a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - } - break; - case 2: - a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - break; - default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - } - } - - for(int i = 0; i < group_count; i++) - { - a_tensors_device.emplace_back( - std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); - b_tensors_device.emplace_back( - std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); - c_tensors_device.emplace_back(std::make_unique( - c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); - - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i].emplace_back(std::make_unique( - d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); - } - - a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); - } - c_tensors_device[i]->SetZero(); - - p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); - p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); - p_Ds.push_back( - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); - p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); - - // The device op does not have to know M problem size at lunch time. 
- gemm_descs.push_back({0, - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - problem_size.stride_Cs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); - ggemm_kargs.push_back( - {a_tensors_device[i]->GetDeviceBuffer(), - b_tensors_device[i]->GetDeviceBuffer(), - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, - problem_size.stride_Cs[i]}); - } - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - // do GEMM - auto argument = gemm.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! 
device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), - ggemm_kargs.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); - - invoker.Run(argument, StreamConfig{nullptr, false, 1}); - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultipleD; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - auto karg = ggemm_kargs[i]; - auto dev_res_tensor = - Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); - c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], - b_tensors[i], - d_tensors[i], - c_host_tensors[i], - a_element_op, - b_element_op, - cde_element_op); - - ref_invoker.Run(ref_argument); - pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); - } - - std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; - } - - if(config.time_kernel) - { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - return pass; -} - -std::vector argToIntArray(char* input) -{ - std::vector out; - std::istringstream in(input); - std::string item; - - while(std::getline(in, item, ',')) - { - out.push_back(std::stoi(item)); - } - return out; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - if(argc < 10) - { - std::vector Ms{64, 127, 255, 129, 260, 190, 77}; - problem_size.group_count = Ms.size(); - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(Ms[i]); - problem_size.Ns.push_back(252); - problem_size.Ks.push_back(4608); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - - problem_size.stride_Ds.push_back({}); - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); - } - } - - std::cout - << "Usage:\n" - << "arg1: verification (0=no, 1=yes)\n" - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: time kernel (0=n0, 1=yes)\n" - << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n" - << "... setting default values." 
<< std::endl; - } - else - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - - problem_size.Ms = argToIntArray(argv[4]); - problem_size.Ns = argToIntArray(argv[5]); - problem_size.Ks = argToIntArray(argv[6]); - - problem_size.stride_As = argToIntArray(argv[7]); - problem_size.stride_Bs = argToIntArray(argv[8]); - problem_size.stride_Cs = argToIntArray(argv[9]); - - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds.push_back(problem_size.stride_Cs); - } - - problem_size.group_count = problem_size.Ms.size(); - } - - return !run_grouped_gemm(problem_size, config); -} +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp index e4da397c23..e942aad1c1 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp index d5b2205892..fb3a6f0b4f 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| 
| PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 
1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 764b533455..ffd0c5e9b7 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: async hargs (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: async hargs (0=no, 1=yes)\n"); printf("arg5: group count (default=16)\n"); #if defined(EXAMPLE_USE_SPLITK) printf("arg6: k-batch count (default=1)\n"); diff --git a/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc new file mode 100644 index 0000000000..a71a23ab79 --- /dev/null +++ b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc @@ -0,0 +1,341 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + + // GEMM shape + std::vector gemm_descs; + std::vector ggemm_kargs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + ggemm_kargs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDs>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector 
size + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < 
NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + // The device op does not have to know M problem size at lunch time. 
+ gemm_descs.push_back({0, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); + ggemm_kargs.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, + problem_size.stride_Cs[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + ggemm_kargs.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = ggemm_kargs[i]; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +bool run_grouped_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 10) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "... setting default values." 
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return run_grouped_gemm(problem_size, config); +} diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index 08915fdd26..a30bedf282 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -268,7 +268,7 @@ int main() pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); } - bool time_kernel = true; + bool time_kernel = false; if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index 7a81d82c25..3401494625 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -302,7 +302,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp 
index 5a127d1cd4..e4960668eb 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 29be3dde0a..c97fa7ebc5 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index 0574488e04..f32d5e9f6d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -106,7 +106,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp index 7da40adc90..6c9fb8da75 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -108,7 +108,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp 
index 47f1d50ef5..4a63bee894 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -105,7 +105,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index cac3db3078..ebd71f1799 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index 5ea09cfab2..1153a66615 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index 8e120851ec..6b5dde3cc7 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -112,7 +112,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 
1024; diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp index 47b0e1d5a5..4f21c70562 100644 --- a/example/22_cgemm/cgemm_xdl_int4.cpp +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -81,7 +81,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // CGEMM shape ck::index_t M = 1024; diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp index a741cb8133..0455819cdc 100644 --- a/example/23_softmax/softmax_blockwise.cpp +++ b/example/23_softmax/softmax_blockwise.cpp @@ -65,7 +65,7 @@ class SimpleAppArgs bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: void show_usage(const char* cmd) diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 12d7cf0aa6..86a36d53e2 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -27,7 +27,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; template diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index f7663cbd0a..6295cfdd04 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -17,7 +17,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" using ::ck::DeviceMem; using 
::ck::HostTensorDescriptor; @@ -69,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G1_M2_N3_K1::Argument; - - float Run(const Argument& arg) - { - auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c; - }; - - make_ParallelTensorFunctor(f_gs_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - 
arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M3_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -353,16 +217,18 @@ int main(int argc, char* argv[]) Tensor c_gs_ms_ns_host_result( e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1; + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceBatchedContraction_G1_M2_N3_K1; auto ref_gemm = ReferenceOpInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -399,7 +265,13 @@ int main(int argc, char* argv[]) } } - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + std::cout << "Verification: " << (pass ? 
"SUCCESS" : "FAILURE") << "!" << std::endl; + + if(!pass) + { + return 1; + } } return 0; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp index 736dc09867..3adfecc7ae 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -17,6 +17,8 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; using ::ck::make_ParallelTensorFunctor; @@ -67,142 +69,6 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -template = - false> -struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G1_M3_N2_K1::Argument; - - float Run(const Argument& arg) - { - auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; 
- - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_gs_ns_ks_(g0, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_gs_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G1_M3_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -353,17 +219,18 @@ int main(int argc, char* 
argv[]) Tensor c_gs_ms_ns_host_result( e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1; + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceBatchedContraction_G1_M3_N2_K1; auto ref_gemm = ReferenceOpInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -400,7 +267,13 @@ int main(int argc, char* argv[]) } } - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + bool pass = ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + + if(!pass) + { + return 1; + } } return 0; diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index d5d5521370..6cf93215f8 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -3,3 +3,4 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_wmma_v3_fp16 batched_gemm_bias_e_permute_wmma_v3_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 6efed7eb29..f102a0b132 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -106,352 +106,5 @@ using DeviceOpInstanceKKNN = using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator -{ - // 
Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool 
IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G2_M2_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; - - ck::index_t G0 = 1; - ck::index_t G1 = 2; - - ck::index_t M0 = 4; - ck::index_t M1 = 128; - - ck::index_t N0 = 16; - ck::index_t N1 = 256; - - ck::index_t K0 = 2048; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 11) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - G0 = std::stoi(argv[4]); - G1 = std::stoi(argv[5]); - M0 = std::stoi(argv[6]); - M1 = std::stoi(argv[7]); - N0 = std::stoi(argv[8]); - N1 = std::stoi(argv[9]); - K0 = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); - exit(0); - } - - // A[G0, G1, M0, M1, K0] 
- std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; - std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; - // B[G0, G1, N0, N1, K0] - std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; - std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; - - // D[G0, G1, M0, N0, M1, N1] - std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; - // E[G0, G1, M0, N0, M1, N1] - std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector e_gs_ms_ns_strides{ - G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; - std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; - std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; - std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * 
d_gs_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * - e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); - b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); - d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_gs_ms_ks_lengths, - a_gs_ms_ks_strides, - b_gs_ns_ks_lengths, - b_gs_ns_ks_strides, - std::array, 1>{d_gs_ms_ns_lengths}, - std::array, 1>{d_gs_ms_ns_strides}, - e_gs_ms_ns_lengths, - e_gs_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t G = - ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); - - ck::index_t M = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl; - std::size_t flop = std::size_t(2) * G * M * N * K; - std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + - sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; - - float tflops = static_cast(flop) / 1.E9 
/ ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result( - e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) - { - for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) - { - for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) - { - for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) - { - for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) - { - for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; - ++n1) - { - cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); - } - } - } - } - } - } - - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 
0 : 1; - } - - return 0; -} +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..4e34f18b8b --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_v3_fp16.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::make_ParallelTensorFunctor; +using ::ck::Tensor; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + 
+using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceOpInstanceKKNN = + ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + ASpec, + BSpec, + DESpec, + 128, + 64, + 64, + 64, + 4, + 4, + 16, + 16, + 1, + 4, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + false, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, + false, + 1, + 1, + S<1, 64, 1, 2>, + S<8, 8>>; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp index d7f468bc62..4ed054faaa 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -67,340 +67,5 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = - false> -struct ReferenceContraction_G2_M2_N2_K1 : public 
ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_gs_ms_ks_{a_gs_ms_ks}, - b_gs_ns_ks_{b_gs_ns_ks}, - e_gs_ms_ns_{e_gs_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_gs_ms_ks_; - const Tensor& b_gs_ns_ks_; - Tensor& e_gs_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, - ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); - arg.b_element_op_( - v_b, - ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); - - v_acc += v_a * v_b; - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_gs_ms_ns_.mDesc.GetLengths()[0], - arg.e_gs_ms_ns_.mDesc.GetLengths()[1], - arg.e_gs_ms_ns_.mDesc.GetLengths()[2], - arg.e_gs_ms_ns_.mDesc.GetLengths()[3], - arg.e_gs_ms_ns_.mDesc.GetLengths()[4], - arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return 
Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_gs_ms_ks, - const Tensor& b_gs_ns_ks, - Tensor& e_gs_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{ - a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_G2_M2_N2_K1" - << std::endl; - // clang-format on - - return str.str(); - } -}; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - ck::index_t G0 = 1; - ck::index_t G1 = 2; - - ck::index_t M0 = 4; - ck::index_t M1 = 256; - - ck::index_t N0 = 16; - ck::index_t N1 = 128; - - ck::index_t K0 = 64; - - // A[G0, G1, M0, M1, K0] - std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; - std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; - // B[G0, G1, N0, N1, K0] - std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; - std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; - - // D[G0, G1, M0, N0, M1, N1] - std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; - // E[G0, G1, M0, N0, M1, N1] - std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector e_gs_ms_ns_strides{ - G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - - if(argc == 1) - { - // use default case - } - 
else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; - std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; - std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; - std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * - e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); - b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); - d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); - - // set zero - 
e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_gs_ms_ks_lengths, - a_gs_ms_ks_strides, - b_gs_ns_ks_lengths, - b_gs_ns_ks_strides, - std::array, 1>{d_gs_ms_ns_lengths}, - std::array, 1>{d_gs_ms_ns_strides}, - e_gs_ms_ns_lengths, - e_gs_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t G = - ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); - - ck::index_t M = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * G * M * N * K; - std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + - sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result( - e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); - - using 
ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) - { - for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) - { - for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) - { - for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) - { - for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) - { - for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; - ++n1) - { - cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), - d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); - } - } - } - } - } - } - - return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 
0 : 1; - } - - return 0; -} +#include "run_batched_gemm_bias_e_permute_example.inc" +int main(int argc, char* argv[]) { return !run_batched_gemm_bias_e_permute_example(argc, argv); } diff --git a/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc b/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc new file mode 100644 index 0000000000..803c1eb0bf --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/run_batched_gemm_bias_e_permute_example.inc @@ -0,0 +1,350 @@ + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + 
+ v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int run_batched_gemm_bias_e_permute_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 128; + + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + ck::index_t K0 = 2048; + + if(argc == 1) + { + // use default case + } + else if(argc == 
4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + G0 = std::stoi(argv[4]); + G1 = std::stoi(argv[5]); + M0 = std::stoi(argv[6]); + M1 = std::stoi(argv[7]); + N0 = std::stoi(argv[8]); + N1 = std::stoi(argv[9]); + K0 = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); + exit(0); + } + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " 
<< e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = + ck::accumulate_n(e_gs_ms_ns_lengths.begin(), 
NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl; + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + 
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result); + } + + return 1; +} diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e1939d4300..dce9f62293 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -92,7 +92,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParam \ diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index ca8cba039f..2b27405ecd 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -92,7 +92,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParam \ diff --git a/example/33_multiple_reduce/dual_reduce_common.hpp b/example/33_multiple_reduce/dual_reduce_common.hpp index 3f04af5e89..923b5b6f15 100644 --- a/example/33_multiple_reduce/dual_reduce_common.hpp +++ b/example/33_multiple_reduce/dual_reduce_common.hpp @@ -40,7 +40,7 @@ class SimpleAppArgs bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; public: SimpleAppArgs() diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp index d0f03f3611..8bf09ee786 100644 --- a/example/35_splitK_gemm/common.hpp +++ b/example/35_splitK_gemm/common.hpp @@ -44,7 +44,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 2; - bool time_kernel = true; + bool time_kernel = false; }; template diff --git 
a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index 2f290497c9..ea8858b958 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -56,7 +56,7 @@ template<> struct emb_kernel { using kernel_type = DeviceInsta int main(int argc, char* argv[]) { - bool time_kernel = true; + bool time_kernel = false; ck::index_t num_rows = 65536; constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index dc0b95863e..ab87124c6b 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -195,7 +195,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t M = 1024; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index c6cc9c6a15..9e7039461c 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -86,7 +86,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp 
b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 0f49cb5a38..fa6a36c212 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp index 5652cc38ab..45651da757 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -87,7 +87,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp index 138a214127..cda4c1419c 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index 1652cea214..0e52ac280a 100644 --- 
a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index f127940377..9bff452a67 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -90,7 +90,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 7a03a3efe0..17a7b632af 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index 155024dc62..345277e092 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ 
b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance = int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index b1596b5a53..d5f9b831f0 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -12,7 +12,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) ck::index_t C = 128; bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; bool log_kernel = true; if(argc == 1) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp index 14b338c9c5..e90880dabd 100644 --- a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -53,7 +53,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; std::vector nchw = {16, 128, 32, 64}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index a7d139fc95..2b99d9261f 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -46,7 +46,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp 
b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index cd1db4cdaf..276aa7f3c7 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 683c5cb072..0842325bad 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index abfd3ccf7c..a48f2349c9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -49,7 +49,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index ff4e8f3a3d..39d88c47a1 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -50,7 +50,7 @@ using 
DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 939860bf69..3aef0fdaac 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -121,7 +121,7 @@ void reference_scale_permute_amax(Tensor& input, int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; const float scale = 2.f; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 497f1c67c8..86af00e4fb 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -58,7 +58,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; if(argc == 1) { diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index eb95128f38..71cee9c420 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -84,7 +84,7 @@ void host_elementwise2D(HostTensorC& C, int main(int argc, char* argv[]) { bool do_verification = true; - bool time_kernel = true; + bool time_kernel = false; ck::index_t M = 48 * 256; ck::index_t N = 1024; diff --git 
a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp index 24c58bb69a..1e3d946bad 100644 --- a/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp @@ -31,8 +31,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F16; using B0DataType = F16; @@ -139,11 +140,11 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index c0452b6067..10f7a38863 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -205,7 +205,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // GEMM shape ck::index_t N = 4096; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index ecc3034bba..d6082e5882 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -193,7 +193,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool 
time_kernel = false; #if 1 // GEMM shape ck::index_t N = 4096; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp index ae707e74a2..ccb3a9c435 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp @@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight. #if 1 -static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MPerBlock = 64; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); @@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, + int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< @@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, 
IsSplitK, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, + int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>; #endif // clang-format on @@ -182,12 +184,14 @@ int main(int argc, char* argv[]) bool time_kernel = true; #if 1 // GEMM shape - ck::index_t N = 4096; - ck::index_t K = 6144; + ck::index_t N = 1536; + ck::index_t K = 4096; + // ck::index_t N = 4096; + // ck::index_t K = 6144; // ck::index_t N = 128; // ck::index_t K = 512; - ck::index_t experts = 8; - ck::index_t topk = 2; + ck::index_t experts = 16; + ck::index_t topk = 8; // ck::index_t sorted_tile_num = 515; // ck::index_t valid_tile_num = 512; // ck::index_t tokens = 208; @@ -196,9 +200,9 @@ int main(int argc, char* argv[]) // ck::index_t sorted_tile_num = 259; // ck::index_t valid_tile_num = 256; // ck::index_t tokens = 4096; - ck::index_t sorted_tile_num = 2; - ck::index_t valid_tile_num = 2; - ck::index_t tokens = 32; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 16; + ck::index_t tokens = 4; #else // deepseek ck::index_t N = 2048; @@ -209,7 +213,7 @@ int main(int argc, char* argv[]) ck::index_t sorted_tile_num = 261; ck::index_t valid_tile_num = 256; #endif - ck::index_t KBatch = 6; + ck::index_t KBatch = 1; if(argc == 1) { // use default case diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 0067c1d1fb..a2002270dc 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -194,7 +194,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp 
b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index a602838c30..9f4cd13573 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -185,7 +185,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index fb5e3b6456..552d3cd7b5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -188,7 +188,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // tokens = 1 // topk = 1 diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index f56410d37a..377b53b519 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -164,7 +164,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc b/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc index 2de3222380..10dce7fe64 100644 --- a/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc +++ b/example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc @@ -5,6 +5,8 @@ int run_gemm_example(int argc, char* argv[]) { + using Bypass = ck::tensor_layout::BypassLayoutVerification; + bool do_verification = true; int init_method = 1; bool time_kernel = false; @@ -64,11 +66,11 @@ 
int run_gemm_example(int argc, char* argv[]) if(std::is_same::value) { - return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index 3ce059ba20..586ecd81bf 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -178,7 +178,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index d1d601977d..b3b2ebcbc0 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -178,7 +178,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 0078cc5625..5c7668ab73 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -208,7 +208,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 202241d14f..04c3afc62b 100644 --- 
a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -171,7 +171,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 660ccabc94..12bb76eccd 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -171,7 +171,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index f398959114..6a5f5a6b9f 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -204,7 +204,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; // per expert: // GEMM shape diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp index 362dc2fff2..12d4b381b2 100644 --- a/example/68_gemm_add/common.hpp +++ b/example/68_gemm_add/common.hpp @@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } - else if(argc == 13) + else if(argc == 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/68_gemm_add/run_gemm_add_example_wmma.inc b/example/68_gemm_add/run_gemm_add_example_wmma.inc index ba15d03e07..0f2cc08edf 100644 --- 
a/example/68_gemm_add/run_gemm_add_example_wmma.inc +++ b/example/68_gemm_add/run_gemm_add_example_wmma.inc @@ -6,6 +6,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/68_gemm_add/run_gemm_add_example_xdl.inc b/example/68_gemm_add/run_gemm_add_example_xdl.inc index da22230a4e..186423d32f 100644 --- a/example/68_gemm_add/run_gemm_add_example_xdl.inc +++ b/example/68_gemm_add/run_gemm_add_example_xdl.inc @@ -6,6 +6,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp index e54c5317ae..de84d69a5e 100644 --- a/example/69_gemm_add_relu/common.hpp +++ 
b/example/69_gemm_add_relu/common.hpp @@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } - else if(argc == 13) + else if(argc == 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc index 8deac6dec8..c3cfd00ab3 100644 --- a/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc @@ -6,6 +6,7 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc index df7474bab5..cca85aa11c 100644 --- a/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc @@ -6,6 +6,7 @@ bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; + using Bypass = ck::tensor_layout::BypassLayoutVerification; auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; @@ -13,11 +14,11 @@ bool 
run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& c [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{}); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{}); } }; diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index aed19c083a..c39f89fcaf 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -6,6 +6,35 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include ) +if(WIN32) + # On Windows, HIP uses -nostdlib which prevents C runtime linking + # We need legacy_stdio_definitions.lib to provide vfprintf and other legacy C functions + # This is mainly needed for the getopt library. + set(LEGACY_STDIO_SEARCH_PATHS) + + # Try to use Visual C++ Tools environment variable (if build executes from Visual Studio Developer Command Prompt) + if(DEFINED ENV{VCToolsInstallDir}) + list(APPEND LEGACY_STDIO_SEARCH_PATHS "$ENV{VCToolsInstallDir}/lib/x64") + endif() + + # Fallback: Search common Visual Studio installation locations + file(GLOB MSVC_LIB_PATHS "C:/Program Files/Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/lib/x64") + list(APPEND LEGACY_STDIO_SEARCH_PATHS ${MSVC_LIB_PATHS}) + + # Use find_library to locate the library + find_library(LEGACY_STDIO_LIB legacy_stdio_definitions + PATHS ${LEGACY_STDIO_SEARCH_PATHS} + NO_DEFAULT_PATH + ) + + if(LEGACY_STDIO_LIB) + message(STATUS "Found legacy_stdio_definitions.lib: ${LEGACY_STDIO_LIB}") + add_link_options("SHELL:-Xlinker \"${LEGACY_STDIO_LIB}\"") + else() + message(WARNING "Could not find legacy_stdio_definitions.lib - examples may fail to link.") + endif() +endif() + add_custom_target(examples) @@ -216,6 +245,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) 
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_dependencies(examples ${EXAMPLE_NAME}) set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 95e8379769..9a2d727253 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,6 +36,19 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} +SUPPORTED_PAGE_SIZE = [1, 16, 1024] +SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] +SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] +KV_MEMORY_LAYOUT_ENUM_MAP = { + "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT", + "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT", +} +KV_LOOKUP_TABLE_ENUM_MAP = { + "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", + "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", +} + + FMHA_BATCH_PREFILL_PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", } @@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_lse}, {F_dropout}, {F_qscale}, - {F_occupancy}>; + {F_occupancy}, + false, + {F_page_size}, + {F_kv_memory_layout}, + {F_kv_lookup_table}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * 
ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, @@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< fmha_variant_{F_idx}, fmha_mask_{F_idx}, false, + {F_page_size}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; +using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, 
{F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -230,12 +248,15 @@ class FmhaFwdApiTrait: dpad: str dvpad: str constraint: CppConstraint + kv_memory_layout: str + kv_lookup_table: str + page_size: int = 1 # page block size @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" ) @property @@ -322,6 +343,8 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_kv_memory_layout: str # + F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -382,6 +405,8 @@ class FmhaFwdPipeline: n += f"_{self.F_qscale}" else: n += "_nqscale" + + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -440,6 +465,13 @@ class FmhaFwdApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, 
F_dtype=FWD_DTYPE_MAP[dtype], + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + trait.kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + trait.kv_lookup_table + ], + F_page_size=trait.page_size, ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -497,6 +529,7 @@ class FmhaFwdKernel: F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline mask_impl: str + F_page_size: int = 1 # page block size @property def template(self) -> str: @@ -534,17 +567,24 @@ class FmhaFwdKernel: F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_occupancy=self.F_tile.F_occupancy, + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + self.F_pipeline.F_kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + self.F_pipeline.F_kv_lookup_table + ], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], + F_page_size=self.F_page_size, ) @property def name(self) -> str: # TODO: we don't encode idx here return ( - f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + self.F_tile.name + "_" + self.F_pipeline.name @@ -578,6 +618,9 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + kv_memory_layout=self.F_pipeline.F_kv_memory_layout, + kv_lookup_table=self.F_pipeline.F_kv_lookup_table, + page_size=self.F_page_size, ) @@ -604,23 +647,42 @@ class KernelComponentFactory: pipelines = [] if dtype in ["fp16", "bf16"]: qscale = "no" - for logits, mask, bias, lse, dropout in itertools.product( + for ( + logits, + mask, + bias, + lse, + dropout, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], 
get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for ( + logits, + qscale, + mask, + bias, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], ["pertensor"], get_mask_map(mask_impl).keys(), ["no"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -672,69 +734,75 @@ def get_fwd_blobs( or pipeline.F_logits == "f" ): continue - k = FmhaFwdKernel( - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - 
cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_batch_prefill) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_batch_prefill C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == "fp32" - if not cond: + # Generate kernels for both page_size=16 and page_size=1024 + for page_size in SUPPORTED_PAGE_SIZE: + if page_size == 1 and pipeline.F_kv_memory_layout != "linear": continue + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode 
== "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue - api_pool.register_traits(k.api_trait()) - gen.append(k) + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == "fp32" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) return (api_pool, gen) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index dd65c0298b..81c7b067d3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -315,7 +315,7 @@ class FmhaFwdApiTrait: assert False def seqtune(self, max_bm0: int) -> str: - if self.bm0 == max_bm0: + if self.bm0 == max_bm0 or self.bm0 == 64: return "true/*fall back to largest tile*/" else: return f"a.seqlen_q <= {self.bm0}" @@ -847,6 +847,11 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) and kernel_ctx.tile.F_bm0 != 128 ) + or ( + (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) + and kernel_ctx.pipeline.tag != "qr_async" + and kernel_ctx.tile.F_bk0 == 64 + ) ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 @@ -942,6 +947,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, 
CppConstraint('get_num_blocks(64) <= num_cus')), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 6f2616cae5..f5ad6b2bc5 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[]) .insert("kv_eff_lens", "", "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); + "Comma-separated list of length 'b'. If empty, no override.") + .insert("init_sink", "0", "value to init the output tensor sink value for validation"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); + int init_sink_value = arg_parser.get_int("init_sink"); ck_tile::stream_config stream_config{nullptr, true, @@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser) init_method, seed, do_validation, + init_sink_value, stream_config, json); } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722..fdd720fd75 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,7 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. 
(Used with padding) + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) + const void* sink_ptr; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -519,6 +523,7 @@ struct fmha_batch_prefill_args // 1) + // kargs.kv_last_page_lens[b] const void* seqstart_q_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -529,14 +534,25 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; - // SGLang-style page table - int32_t num_total_pages; - void* kv_indptr; - void* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - void* kv_last_page_lens; - ck_tile::index_t page_block_size; -#endif + // KV cache page table fields (kv_lookup_table selects interpretation): + // - SGLANG_PAGE_TABLE_1D: + // kv_indptr: prefix-sum [batch+1] into kv_page_indices + // kv_page_indices: 1D list of physical page ids, length = num_total_pages + // kv_last_page_lens: per-batch last page lengths [batch] + // - VLLM_BLOCK_TABLE_2D: + // kv_page_indices: block_table [batch, max_blocks_per_seq] (2D) + // batch_stride_block_table: row stride for block_table + // seqlen_k_ptr: per-batch seqlen_k [batch] + int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM) + ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM) + 
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum + kv_memory_layout; // KV memory layout (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector + void* kv_indptr; // SGLang: prefix-sum; vLLM: unused + void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D + void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused + void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused + ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused float scale_s; float scale_p; @@ -627,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -677,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } }(); @@ -837,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_right, args.sink_size, args.mask_type, - args.min_seqlen_q); + args.min_seqlen_q, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -882,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -949,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -997,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -1113,6 
+1135,22 @@ template auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) { assert(args.nhead_q % args.nhead_k == 0); + using PageTableKargs = typename FmhaKernel::PageBlockTableKargs; + const PageTableKargs page_table = [&]() { + if constexpr(FmhaKernel::kKVLookupTable == + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return PageTableKargs{reinterpret_cast(args.kv_indptr), + reinterpret_cast(args.kv_page_indices), + reinterpret_cast(args.kv_last_page_lens)}; + } + else + { + return PageTableKargs{reinterpret_cast(args.kv_page_indices), + args.batch_stride_block_table, + reinterpret_cast(args.seqlen_k_ptr)}; + } + }(); auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) @@ -1133,12 +1171,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1164,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -1184,12 +1219,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1220,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } }(); @@ -1281,6 
+1313,65 @@ struct fmha_fwd_traits_ static constexpr bool kHasSink = kHasSink_; }; +template +struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_; + static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout"); +}; + template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); @@ -1527,7 +1618,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); -using fmha_batch_prefill_traits = fmha_fwd_traits; +struct fmha_batch_prefill_traits : public fmha_fwd_traits +{ + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + int page_size = 1; +}; + float fmha_batch_prefill(fmha_batch_prefill_traits, fmha_batch_prefill_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 536fcb0692..0c988b2acc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -149,6 +149,28 @@ int override_num_splits_if_necessary( return num_splits; } +template +void copy_attention_scores_with_sink(const ck_tile::HostTensor& s_host_ref, + const ck_tile::HostTensor& sink_host, + ck_tile::HostTensor& s_with_sinks_ref, + ck_tile::index_t nhead, + ck_tile::index_t real_seqlen_q, + ck_tile::index_t real_seqlen_k) +{ + for(auto i_h = 0; i_h < nhead; i_h++) + { + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); + } + // Append sink token at the 
end of each row + s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h); + } + } +} + template fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t batch, @@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode, std::string init_method, uint32_t seed, int do_validation, + int init_sink_value, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { @@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor sink_host({nhead}); ck_tile::HostTensor k_host( 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) @@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( bias_host); } + else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); - + if(init_sink_value != 0) + { + // sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range + // for close to rowmax values. 
+ ck_tile::FillUniformDistributionIntegerValue{30.f, 60.f, next_seed()}( + sink_host); + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); @@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode, q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); + sink_buf.ToDevice(sink_host.data()); knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); @@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode, args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); args.v_ptr = v_buf.GetDeviceBuffer(); - + if(init_sink_value != 0) + args.sink_ptr = sink_buf.GetDeviceBuffer(); + else + args.sink_ptr = nullptr; args.batch = batch; args.seqlen_q = shape_seqlen_q; // unused in group mode args.hdim_q = hdim_q; @@ -1351,8 +1387,8 @@ fwd_result fmha_fwd_run(mode_enum mode, auto oacc_element_func = [&]() { if constexpr(std::is_same_v && supports_qscale) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o_host}); + return ck_tile::make_composes(ck_tile::saturates{}, + ck_tile::scales{scale_o_host}); else if constexpr(supports_qscale) return ck_tile::scales{scale_o_host}; else @@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode, mask.type == mask_enum::mask_top_left)); } const ck_tile::HostTensor masked_s_host_ref = s_host_ref; - if(lse) + if(init_sink_value != 0) { - ck_tile:: - reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + // Create 
extended tensor with sink token + ck_tile::HostTensor s_with_sinks_ref( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + // Copy original attention scores and append sink values + copy_attention_scores_with_sink( + s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k); + + // Compute softmax on extended tensor + ck_tile::HostTensor p_extended( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + if(lse) + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func); + } + + // Extract only the original columns (exclude sink token column) + p_host_ref.ForEach( + [&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); }); } else { - ck_tile:: - reference_batched_softmax( + // No sink tokens - compute softmax directly + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, p_compute_element_func); + } } - if(p_drop > 0) { ck_tile::HostTensor randval_host_ref( diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 664c825418..5c9d3132b3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 1 # l=2/r=0(br) l=2/r=0/s=2(br) +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0 + +$EXE -prec=fp16 
-mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp index 77a9fe4271..df8351602b 100644 --- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -69,107 +69,88 @@ struct BasicInvoker using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
- using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index c312a53c2a..d2460193d8 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmKernel = ck_tile::GemmKernel; - using GemmKernel = ck_tile::GemmKernel; + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - ck_tile::DeviceMem ws_m_n_dev_buf(args.M * 
args.N * sizeof(WorkspaceType)); - ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); - auto c_ptr = ws_args.c_ptr; - ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); - const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) - : GemmKernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernel::BlockSize(); + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!GemmKernel::IsSupportedArgument(gemm_kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; - ck_tile::index_t total_elements = 1; - std::vector shape = {args.M, args.N}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; - const ck_tile::index_t kBlockSize = 
ElementwiseKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); - auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); - auto input_size = ck_tile::make_tuple(args.M, args.N); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - gemm_kargs.as_ptr[0], - gemm_kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel( - GemmKernel{}, grids, blocks, 0, gemm_kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(args.N, 1), // Input Stride - ck_tile::make_tuple(args.N, 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index c06dc457c9..64305b85cf 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_E); constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&]() { - // use SET operation since each K-split writes to separate memory - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - using 
GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = + ck_tile::CShuffleEpilogue>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(base_args); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + const dim3 blocks = Kernel::BlockSize(); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - return ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - }; - - return Run(); + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } /** diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f79494a478..8eff0e7469 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -460,12 +460,6 @@ inline auto create_args() return arg_parser; } -// Type aliases for memory operation integral constants -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - // host API template ::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" - << std::endl; - } - float ave_time = 0.f; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper rotating_mem(kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = - ck_tile::launch_kernel_time_mask(s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - return Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - throw std::runtime_error("split-k is not supported yet!"); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl; + } + float ave_time = 0.f; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; } }; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 4a83a2c4ab..fb89e6b4cc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -60,112 +60,94 @@ struct UniversalInvoker using 
GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GemmKernel; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) - : Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) + : Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 715ed35394..074b594534 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) +# Multi Reduce Threadwise Example +set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise") +add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp) +target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS}) + +# Multi Reduce Blockwise Example +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock") +add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp) 
+target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS}) + # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp new file mode 100644 index 0000000000..2384dc2aa5 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp @@ -0,0 +1,271 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "19", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) 
+{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = float; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." 
<< std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + std::vector h(number_operations * N * C); + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + // Operations: one doing a sum reduction, the other computing the mean square + // In the case of mean square: + // 1. The element wise operation squares each element before reduction + // 2. The reduction operation sum the squared element + // 3. The accumulator element wise operation divides the result by the total number of reduced + // elements (intra block operation) + // 4. The partial result is updated across blocks using inter block reduction, a sum. 
+ auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions + auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise + // ops + auto accumulator_elementwise_ops = ck_tile::make_tuple( + ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block + auto inter_block_reduce_ops = ck_tile::make_tuple( + ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceMultiblock; + + // Determine block group size for multi-block reduction + // block_group_size records how many blocks participate in a reduction (input data dependent) + // , for efficiency reasons this size is limited to a maximum of 128.
If this is not sufficient + // to process the whole reduction, each thread will process multiple thread tiles, + // num_block_tile_iterations times + auto [num_block_tile_iterations, block_group_size] = + typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams(); + + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + ck_tile::index_t kGridSize = + ((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size; + + std::cout << "Block group size: " << block_group_size + << ", Num block tile iterations: " << num_block_tile_iterations + << ", Reduce total length: " << reduce_total_length << std::endl; + std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl; + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! 
Arguments not supported!\n"); + } + + // Init the output data with identity values respective to each reduce op + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + constexpr auto op = reduce_ops.at(i); + const auto identity_val = op.template GetIdentityValue(); + const auto output_number_elements = N * C; + std::fill(h.begin() + i * output_number_elements, + h.begin() + (i + 1) * output_number_elements, + identity_val); + }); + + auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); }; + + float ave_time = launch_kernel_time_mask( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + clear_output_buffer, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops) + + ); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_multiple_reduce_multiblock( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops, + block_group_size); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + std::cout << "Checking operation " << i << ": " << std::endl; + + bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + 
y_host_ref_tuple.get(ck_tile::number{})); + + if(pass_op) + { + std::cout << "✅ valid results for this operation" << std::endl; + } + pass &= pass_op; + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp new file mode 100644 index 0000000000..c929a7eb82 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp @@ -0,0 +1,224 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "7", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using 
YDataType = DataType; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." 
<< std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + // Two operations: one computes a sum reduction, the other computes the mean square + auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reduction ops + auto elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise ops + auto accumulator_elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementwise ops on reduction, + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + 
+ constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceThreadWise; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! Arguments not supported!\n"); + } + + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops)); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + std::vector h(number_operations * N * C); + + // reference + ck_tile::reference_multiple_reduce( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + 
std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + y_host_ref_tuple.get(ck_tile::number{})); + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d9cb54cf74..a98faf5840 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) if(moe_buf_bytes > 0) { #if MOE_SORTING_FMOE_2D_BUF - printf("moe_buf:%lu(%d,%d), ", + printf("moe_buf:%" PRIu64 "(%d,%d), ", static_cast(moe_buf_bytes), moe_buf_interm_dim, moe_buf_elem_bytes); #else - printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); + printf("moe_buf:%" PRIu64 ", ", static_cast(moe_buf_bytes)); #endif } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index c7e37bc8a7..b68c30351d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 9b51af22fe..0f0a0d8ba7 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") quant_grouped_gemm_bf8_rowcol.cpp quant_grouped_gemm_bf8_tensor.cpp ) - + add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) @@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp new file mode 100644 index 0000000000..84da1e26da --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" +#include "abquant_grouped_gemm.hpp" + +// Non-persistent grouped gemm for ABQuant +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = 
GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); +} + +#include "run_grouped_gemm_abquant_example.inc" + +int main(int argc, char* argv[]) +{ + int result1 = run_abquant_grouped_gemm_example(argc, argv); + return result1; +} diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp new file mode 100644 index 0000000000..da8bd5514c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -0,0 +1,171 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); +}; + +template 
+struct GemmQuantConfig; + +// ABQuant specialization for GemmQuantConfig +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert( + "stride_As", + "", + "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs + // can be set to zero if + // Ms/Ns/Ks is not empty + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") + .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.") + .insert("bquant_group_size", "1x1x128", "BQuant group size. 
1x1x128 (default) or 1x128x128") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); +} + +// Forward declaration of the non-persistent version +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); + +// Forward declaration of the tileloop version for persistent kernels +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 3ff3f2f10e..a24e4bc8ab 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -62,71 +62,55 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " 
grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BLayout, CLayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = 
memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << 
blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(!splitk) + if(s.log_level_ > 0) { - return ave_time = Run(ck_tile::integral_constant{}); - } - else - { - return ave_time = - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 67b411c1f0..462f11e405 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -328,5 +328,4 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk = false); + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 060dd311b5..e5aefad8d1 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector& gemm_d using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - 
ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + 
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, BLayout, ELayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. 
- using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_multi_d_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4a5be996c0..b4c10900d6 100644 --- 
a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -65,70 +65,54 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + 
s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = 
typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout (empty for no D tensors) - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - - if(splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp index 16352722e1..ea71abb213 100644 --- 
a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp @@ -72,10 +72,9 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -137,8 +136,7 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; + using GemmPipeline 
= GemmQuantConfig::template GemmPipeline; - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - return ave_time = 
Run(ck_tile::integral_constant{}); + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc new file mode 100644 index 0000000000..bc5167439d --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc @@ -0,0 +1,604 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_abquant_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. 
+ ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_abquant( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); + } + + return ave_time; +} + +template +int run_abquant_grouped_gemm_example_with_layouts( + int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + const BQLayout bq_layout = 
BQLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + + auto [result, arg_parser] = create_args(argc, argv); + + auto valid_input_data = [&](int group_count, const auto&... args) { + return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector AQs; // dimension of AQ tensor is calculated from A tensor + std::vector BQs; // dimension of BQ tensor is calculated from B tensor + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + std::vector stride_AQs = arg_parser.get_int_vec("stride_AQs"); + std::vector stride_BQs = arg_parser.get_int_vec("stride_BQs"); + + ck_tile::index_t AQK, BQK; + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs)) + { + std::cout << "Please check the input data. Default values will be used." 
<< std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + // Let get_default_stride calculate based on layout + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // For ABQuantGrouped, both A and B need quantization + static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped, + "This file only supports ABQuantGrouped mode"); + + AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize + BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K 
must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; + + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } + + 
a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back( + std::make_unique(aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back( + std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + kbatch, + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + float ave_time = invoke_abquant_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name 
<< std::endl; + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Reference implementation for ABQuantGrouped + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_abquant_grouped_gemm_example_prec_type_with_bquant( + std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + + constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + if(a_layout == "R" && b_layout == "C" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template +int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + std::string c_layout, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(bquant_group_size == "1x1x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return 
run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else if(bquant_group_size == "1x128x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128."); + } +} + +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } +} + +int run_abquant_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); + const std::string data_type = arg_parser.get_str("prec"); + bool persistent = arg_parser.get_bool("persistent"); + const std::string bquant_group_size = arg_parser.get_str("bquant_group_size"); + + if(data_type == "fp8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else if(data_type == "bf8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else + { + 
throw std::runtime_error("Unsupported data type configuration."); + } +} diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 390a54644b..7a01b1dcea 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup, // earlier stage. std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup, ADataType, BDataType, AccDataType, - CDataType>(stream, group_count, kargs_ptr, splitk); + CDataType>(stream, group_count, kargs_ptr); } return ave_time; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index ac6ea99db3..4f2bebdf17 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup, else { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, @@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup, kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - ave_time = - grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr); } return ave_time; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp 
b/example/ck_tile/18_flatmm/flatmm_basic.cpp index cd241a2be0..af46884a90 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index da85c95dae..780a21ba14 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const 
ck_tile::stream_config& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. @@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index fe7fe4c5d1..708e8a683e 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) 
{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = std::conditional_t{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index 2b6dbace36..f9f8c0cec7 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& 
FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC @@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 96b9ae29a4..4cca953066 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - 
memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, args.NumTokens * args.TopK * outputN * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, run_flush_cache, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index f177ef04ca..01128f8fe8 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = - Splitk ? 
ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set; + ck_tile::ignore = Splitk; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, MXPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 9e2bc3e3fb..1c56295f9f 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - // Epilogue selection: set to true for chainer-based, false for standard - // CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue, - // Chainer-based epilogue - ck_tile::EpilogueChainer, - ck_tile::DefaultScheduleTag>>, - // Standard CShuffleEpilogue - ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>>; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw 
std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index d2663b033c..ca8573d6d2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp 
b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index afe43cd1c0..90874e6018 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - const auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = 
Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + auto preprocess = [&]() { + if(args.k_batch > 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + ck_tile::hip_check_error(hipMemsetAsync( + kargs.wei_ptr, 0, args.template GetWeightByte(), s.stream_id_)); } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - { - ck_tile::hip_check_error( - hipMemsetAsync(kargs.wei_ptr, - 0, - args.template GetWeightByte(), - s.stream_id_)); - } - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - 
ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return InvokerResult{ave_time, args.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index ad5e8ae70f..c4d618a0bf 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + 
ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * + sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const ck_tile::index_t spatial_lengths_accum = - std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * - sizeof(WorkspaceDataType)); - ck_tile::GroupedConvBwdWeightHostArgs ws_args = - ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const auto kargs = Kernel::MakeKernelArgs(ws_args); + const auto kargs = Kernel::MakeKernelArgs(ws_args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(args.G_ * args.K_), + static_cast(args.C_ * spatial_lengths_accum)}; - ck_tile::index_t total_elements = 1; - std::vector shape = { - static_cast(args.G_ * args.K_), - static_cast(args.C_ * spatial_lengths_accum)}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); - auto input_tensors = - ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); - auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + // Check if the kernel configuration is supported + 
if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - ck_tile::hip_check_error( - hipMemsetAsync(ws_args.wei_ptr, - 0, - shape[0] * shape[1] * sizeof(WorkspaceDataType), - s.stream_id_)); - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel( - ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // 
Output Stride - input_tensors, - static_cast(c_ptr))); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; + auto preprocess = [&]() { + if(args.k_batch > 1) + ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + return InvokerResult{ave_time, kargs.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 82541bb593..c94466aeb2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - 
GroupedConvTraitsType::VectorSizeB>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - // ===================================================================== - // Split-K dispatch - // ===================================================================== - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - return Run(MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4261385a84..5dec340668 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) { - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto enable_split_image_) { constexpr bool EnableSplitImage = enable_split_image_.value; using GroupedConvTraitsType = std::conditional_t>; @@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== if(use_split_image) { - if(args.k_batch == 1) - return 
Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } else { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 63dd54dcae..a78a880815 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -13,11 +13,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - template auto calculate_rtol_atol(const ck_tile::index_t GemmK, const ck_tile::index_t kbatch, diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index acb9126d65..9202bf9d98 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = 
Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 28e52b9275..ec536f7287 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -20,9 +20,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb.cpp - gemm_bquant_quantgrouped_preshufflequant.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + 
gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp gemm_quant_rowcol.cpp gemm_quant_tensor.cpp ) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 4a90c07e05..155f19881e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + 
lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 61fd65960f..82e30e56d2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -49,4 +49,10 @@ void bquant_quantgrouped_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 1d471068eb..515e6eb027 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ 
b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -51,4 +51,10 @@ void bquant_quantgrouped_bf8i4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 280029033b..eaf10f057c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index a277c864bb..c91867534f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -51,4 +51,10 @@ void bquant_quantgrouped_fp8i4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; } diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp deleted file mode 100644 index b32356c29d..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb.cpp +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -#if CK_TILE_USE_WMMA -template -using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; -#else -template -using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; -#endif - -void bquant_quantgrouped_preshuffleb_instance_factory( - std::unordered_map>& lut) -{ - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x64x128"})] = [](const 
ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", - "bquant", - "preshuffleb", - "non-preshufflequant", - "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = - decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using 
QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - 
QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - using QuantGroupSize = ck_tile::QuantGroupShape>; - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp new file mode 100644 index 0000000000..7166a5647e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp new file mode 100644 index 0000000000..85599864db 
--- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp new file mode 100644 index 0000000000..87cb4c9d10 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", 
"non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp new file mode 100644 index 0000000000..0cb16441a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x64x128"})] = + 
[](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "preshuffleb", "non-preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp deleted file mode 100644 index 180f353df8..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -#if CK_TILE_USE_WMMA -template -using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; -#else -template -using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; -#endif - -void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( - std::unordered_map>& lut) -{ - using QuantGroupSize = ck_tile::QuantGroupShape>; - lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] 
= - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp new file mode 100644 index 0000000000..640757a956 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp new file 
mode 100644 index 0000000000..575a43afd8 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp new file mode 100644 index 0000000000..9e40fbaa87 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp new file mode 100644 index 0000000000..2552a1d134 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +#if CK_TILE_USE_WMMA +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill_Wmma; +#else +template +using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; +#endif + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = 
ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp deleted file mode 100644 index e0e0a64416..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "run_gemm_quant_example.inc" - -template -using GemmConfig = GemmConfigPreshuffleBQuantPrefill; - -void bquant_quantgrouped_preshufflequant_instance_factory( - std::unordered_map>& lut) -{ - using QuantGroupSize = ck_tile::QuantGroupShape>; - lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - 
[](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; - lut[hash_multiple_strings( - {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = - [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); - return run_gemm_example_prec_type, - TypeConfig, - QuantGroupSize, - ck_tile::QuantType::BQuantGrouped>(arg_parser); - }; -} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp new file mode 100644 index 0000000000..edb28236af --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", 
"1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp new file mode 100644 index 0000000000..59da63447e --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git 
a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp new file mode 100644 index 0000000000..29c88001e8 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = 
ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp new file mode 100644 index 0000000000..f487132557 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp @@ -0,0 +1,59 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", 
"preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x128x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 940c1b8cf3..8de58b0a30 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -111,11 +111,29 @@ void bquant_quantgrouped_bf8i4_instance_factory( std::unordered_map>& lut); void bquant_quantgrouped_bf16fp4_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_instance_factory( +void bquant_quantgrouped_preshuffleb_fp8_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_instance_factory( +void bquant_quantgrouped_preshuffleb_bf8_instance_factory( std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory( +void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( 
+ std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( std::unordered_map>& lut); void quant_rowcol_instance_factory( std::unordered_map>& lut); @@ -144,9 +162,18 @@ int main(int argc, char* argv[]) bquant_quantgrouped_fp8i4_instance_factory(lut); bquant_quantgrouped_bf8i4_instance_factory(lut); bquant_quantgrouped_bf16fp4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_instance_factory(lut); - bquant_quantgrouped_preshufflequant_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); quant_rowcol_instance_factory(lut); quant_tensor_instance_factory(lut); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 47a22cdcba..607c53d9af 100644 --- 
a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, ck_tile::BaseGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::Scheduler, has_hot_loop_v, tail_number_v>>>>; + using AQuantPipeline = + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>; + + using BQuantPipeline = std::conditional_t< + GemmConfig::PreshuffleB, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + + using ABQuantPipeline = + std::conditional_t, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmPipeline = std::conditional_t< QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t< - QuantMode == ck_tile::QuantType::ABQuantGrouped, - ck_tile::ABQuantGemmPipelineAgBgCrCompV3, - std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::WPQuantBPipelineAgBgCrV2, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>>; + std::conditional_t>>; constexpr bool 
TiledPermuteN = (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -173,77 +181,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - - // Epilogue selection: use chainer for RowCol/Tensor quant, standard for others - // Toggle to switch between chainer-based and standard CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; - - // Define the schedule tag based on quant mode - using ScheduleTag = - std::conditional_t>; - - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant), - // Chainer-based epilogue for RowCol/Tensor quant modes - ck_tile::EpilogueChainer, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>, - ScheduleTag>>, - // Standard CShuffleEpilogue for other modes - ck_tile::CShuffleEpilogue, typename TypeConfig::ADataType, - std::conditional_t< - std::is_same_v, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>>; - + typename 
TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -579,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { @@ -955,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - if((QuantMode == ck_tile::QuantType::ABQuantGrouped || - QuantMode == ck_tile::QuantType::AQuantGrouped || + if((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant || std::is_same_v) && GemmConfig::PreshuffleB) @@ -985,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::ABQuantGrouped) && - !GemmConfig::PreshuffleQuant) + !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB) { if(a_layout == "R" && b_layout == "R") { diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index d3ee9fe9c6..828c861349 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -48,112 +48,87 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& 
args, GemmConfiguration::NUM_WAVE_GROUPS, GemmConfiguration::PRESHUFFLE>; - const auto runKernel = [&](const auto memory_operation) -> std::tuple { - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kernel_args = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); - ck_tile::DeviceMem workspace_data(workspace_size); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kernel_args)) + { + // Clear the output C tensor results after each repetition of the kernel + 
hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); + } + + if(stream_config.log_level_ > 0) + { + // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); - kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + } - dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kernel_args)) + auto reset_data_buffers = [&]() { + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } - - if(stream_config.log_level_ > 0) + else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); } - - auto reset_data_buffers = [&]() { - if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) - { - // Clear the output C tensor results after each repetition of the kernel - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); - } - else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) - { - // Reset sk flags to zero before each repetition of the kernel - workspace_data.SetZero(); - } - }; - - 
std::function preprocess = reset_data_buffers; - - float average_time = - ck_tile::launch_kernel_time_mask(stream_config, - preprocess, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kernel_args)); - - ck_tile::index_t num_wgs_per_tile = - kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{average_time, num_wgs_per_tile}; }; - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) - { - return runKernel(ck_tile::integral_constant{}); - } - else // We are using ck_tile::StreamKReductionStrategy::Reduction - { - return runKernel(ck_tile::integral_constant{}); - } + std::function preprocess = reset_data_buffers; + + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); + + ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index f9f13c6e85..1e159a5615 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = GEMM_PIPELINE; - using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = - ck_tile::BatchedContractionKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(kargs); + const 
dim3 blocks = Kernel::GetBlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::GetBlockSize(); + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } - if(!Kernel::IsSupportedArguments(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - - return ck_tile::launch_kernel(s, kernel); - }; - - return Run(); + return ck_tile::launch_kernel(s, kernel); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 940ee3e503..850bcf136e 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -2,13 +2,13 @@ This directory contains the experimental builder feature for composable_kernel. 
-* Status: In development (October - December 2025) +* Status: In development (October 2025 - March 2026) ## Overview The builder provides a high-level, semantically-clear interface for constructing composable kernel operations, with an initial focus on convolution kernels for MIOpen. It leverages modern C++20 features (such as POD structs as non-type template parameters, concepts, and designated initializers) to simplify kernel instantiation and improve developer experience. -This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CKTile, but is currently limited to formalizing the interface between MIOpen and CK. +This project is a prototype for a more general builder pattern for all of composable_kernel (CK) and CK Tile, but is currently limited to formalizing the interface between MIOpen and CK. ## Design descriptions @@ -45,6 +45,11 @@ cmake .. ``` +Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. Add e.g. +`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively, +one can add flag `-D CK_USE_WMMA=ON` to build the tests. For the end-to-end tests that use +the instances from builder, one needs an actual Navi card. 
+ ## Building and Testing The builder test suite is organized into two main categories: diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index 8075e33220..0af0cede60 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; // 1, 2, or 3 - { t.data_type } -> std::convertible_to; // Default data type { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; requires ConvolutionDirectionWellDefinedIfProvided; // Optional direction + requires detail::DataTypeWellDefinedIfProvided; // Optional default data type + requires detail::ElementwiseOpWellDefinedIfProvided; // Optional default elementwise operation }; ``` **Properties:** - **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D) -- **`direction`**: Operation type (optional, defaults to FORWARD) +- **`direction`**: Operation type (Optional, defaults to FORWARD) - `FORWARD`: Standard forward convolution - `BACKWARD_DATA`: Gradient computation w.r.t. input - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights -- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8) +- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE which indicates the type should be inferred or specified per-tensor, may be overridden by individual tensors) +- **`elementwise_operation`**: Default elementwise operation for all tensors (Optional, defaults to PASS_THROUGH, may be overridden by individual tensors via their `operation` field) - **`accumulation_data_type`**: Type used for internal accumulation #### 2. 
Tensor Level @@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) { A tensor descriptor encapsulates: - **Configuration**: Layout and data type information -- **Operation** (optional): Fused elementwise operations on this tensor +- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor) #### 3. Tensor Configuration @@ -126,11 +128,14 @@ Describes the memory layout and data types: template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - { t.data_type } -> std::convertible_to; // Optional override + requires detail::DataTypeWellDefinedIfProvided; // Override data type (Optional, default provided by ConvSignatureDescriptor) }; ``` **Layout Types** (dimension-specific): +- **Special Values**: + - `UNDEFINED_TENSOR_LAYOUT`: Placeholder value indicating layout is not yet specified or should be inferred + - **1D Convolution**: - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided` - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided` @@ -146,6 +151,9 @@ concept TensorConfigDescriptor = requires(T t) { - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided` - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided` +- **Bias Tensors**: + - `GC`, `G_C_strided`, `G_K_strided` + Where: - `G` = Groups - `N` = Batch size diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index bf7e89fcaa..29a04d9b6c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -15,29 +15,31 @@ namespace ck_tile::builder { /* Descriptors for individual elements of the algorithm description */ /********************************************************************/ +// Common concept for size-related fields +template +concept SizeType = std::unsigned_integral>; + 
// Concept for thread block dimensions for a GEMM problem. template concept ThreadBlockDescriptor = requires(T t) { - { t.block_size } -> std::convertible_to; - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.block_size } -> SizeType; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for parameters that describe a gridwise XDL GEMM problem. template concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> std::convertible_to; - { t.bk1 } -> std::convertible_to; - { t.m_per_xdl } -> std::convertible_to; - { t.n_per_xdl } -> std::convertible_to; - { t.m_xdl_per_wave } -> std::convertible_to; - { t.n_xdl_per_wave } -> std::convertible_to; + { t.m_per_xdl } -> SizeType; + { t.n_per_xdl } -> SizeType; + { t.m_xdl_per_wave } -> SizeType; + { t.n_xdl_per_wave } -> SizeType; }; // Concept for parameter that describe block GEMM problem. template -concept BlockGemmDescriptor = requires(T t) { +concept BlockGemmPipelineDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; @@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) { // Concept for parameters that describe a gridwise WMMA GEMM problem. template concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> std::convertible_to; - { t.m_per_wmma } -> std::convertible_to; - { t.n_per_wmma } -> std::convertible_to; - { t.m_wmma_per_wave } -> std::convertible_to; - { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.k1 } -> SizeType; + { t.m_per_wmma } -> SizeType; + { t.n_per_wmma } -> SizeType; + { t.m_wmma_per_wave } -> SizeType; + { t.n_wmma_per_wave } -> SizeType; }; // Concept for vectorized data transfer for convolution input tensors. 
template -concept BlockTransferDescriptor = requires(T t) { - { t.k0 } -> std::convertible_to; - { t.m_n } -> std::convertible_to; - { t.k1 } -> std::convertible_to; +concept BlockTransferDescriptor3D = requires(T t) { + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; }; +template +concept BlockTransferDescriptor4D = requires(T t) { + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; + { t.k_batch_size } -> SizeType; +}; + +template +concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D) || + (ThreadClusterRank == 4 && BlockTransferDescriptor4D); + // Concept for thread cluster dimensions for GEMM output tensor. template concept ThreadClusterDescriptor = requires(T t) { - { t.m_block } -> std::convertible_to; - { t.m_wave_per_xdl } -> std::convertible_to; - { t.n_block } -> std::convertible_to; - { t.n_wave_per_xdl } -> std::convertible_to; + { t.m_block } -> SizeType; + { t.m_wave_per_xdl } -> SizeType; + { t.n_block } -> SizeType; + { t.n_wave_per_xdl } -> SizeType; }; // Concept for the LDS transfer for the convolution input tensors. template concept LdsTransferDescriptor = requires(T t) { - { t.src_vector_dim } -> std::convertible_to; - { t.src_scalar_per_vector } -> std::convertible_to; - { t.lds_dst_scalar_per_vector } -> std::convertible_to; + { t.src_vector_dim } -> SizeType; + { t.src_scalar_per_vector } -> SizeType; + { t.lds_dst_scalar_per_vector } -> SizeType; { t.is_direct_load } -> std::convertible_to; { t.lds_padding } -> std::convertible_to; }; @@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) { // LDS). 
template concept EpilogueDescriptor = requires(T t) { - { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.n_per_wave_per_shuffle } -> std::convertible_to; - { t.scalar_per_vector } -> std::convertible_to; + { t.m_xdl_per_wave_per_shuffle } -> SizeType; + { t.n_per_wave_per_shuffle } -> SizeType; + { t.scalar_per_vector } -> SizeType; }; // Concept for the thread cluster access order template -concept AccessOrderDescriptor = requires(T t) { +concept ThreadClusterOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; +} || requires(T t) { + { t.order } -> std::convertible_to>; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block // size is deduced from block gemm structure). template concept TileThreadBlockDescriptor = requires(T t) { - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block // size is deduced from block gemm structure). template concept TileTransferDescriptor = requires(T t) { - { t.a_scalar_per_vector } -> std::convertible_to; - { t.b_scalar_per_vector } -> std::convertible_to; - { t.c_scalar_per_vector } -> std::convertible_to; + { t.a_scalar_per_vector } -> SizeType; + { t.b_scalar_per_vector } -> SizeType; + { t.c_scalar_per_vector } -> SizeType; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. 
template -concept SpecifiesGridwiseXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; +concept GridwiseFwdXdlGemmDescriptor = requires(T t) { + { t.ak1 } -> SizeType; + { t.bk1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept GridwiseBwdXdlGemmDescriptor = requires(T t) { + { t.k1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. template -concept SpecifiesGridwiseWmmaGemm = requires { - { T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor; +concept SpecifiesGridwiseWmmaGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. -template +template concept SpecifiesBlockTransfer = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. 
template concept SpecifiesTileTransfer = requires(T t) { - { T::transfer.a_scalar_per_vector } -> std::convertible_to; - { T::transfer.b_scalar_per_vector } -> std::convertible_to; - { T::transfer.c_scalar_per_vector } -> std::convertible_to; + { T::transfer.a_scalar_per_vector } -> SizeType; + { T::transfer.b_scalar_per_vector } -> SizeType; + { T::transfer.c_scalar_per_vector } -> SizeType; }; // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. @@ -195,23 +231,27 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. template -concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; - { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; +concept SpecifiesThreadClusterArrangeOrder = requires(T t) { + { T::transfer.a.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor; + { T::transfer.b.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor; }; // Concept to check if a struct specifies source access order info. template concept SpecifiesSourceAccessOrder = requires(T t) { - { T::transfer.a.src_access_order } -> AccessOrderDescriptor; - { T::transfer.b.src_access_order } -> AccessOrderDescriptor; + { T::transfer.a.src_access_order } -> ThreadClusterOrderDescriptor; + { T::transfer.b.src_access_order } -> ThreadClusterOrderDescriptor; }; // Concept to check if struct specifies block GEMM. template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; +}; + +template +concept SpecifiesGridwiseGemmPipeline = requires { + { T::pipeline_version } -> std::convertible_to; }; // Concept to check if struct specifies block GEMM (CK Tile). 
@@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires { template concept SpecifiesFwdConvSpecialization = requires { - { T::fwd_specialization } -> std::convertible_to; + { T::fwd_specialization } -> std::convertible_to; +}; + +template +concept SpecifiesBwdWeightConvSpecialization = requires { + { T::bwd_weight_specialization } -> std::convertible_to; }; template @@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires { template concept SpecifiesNumPrefetchStages = requires { - { T::num_gemm_k_prefetch_stages } -> std::convertible_to; + { T::num_gemm_k_prefetch_stages } -> SizeType; }; template concept SpecifiesNumGroupsToMerge = requires { - { T::num_groups_to_merge } -> std::convertible_to; + { T::num_conv_groups_to_merge } -> SizeType; }; template @@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +}; + +template +concept SpecifiesTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; + { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; +}; + +template +concept HasTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector }; + { T::max_transpose_transfer_dst_scalar_per_vector }; +}; + +template +concept TransposeTransferWellDefinedIfProvided = + !HasTransposeTransfer || SpecifiesTransposeTransfer; + +template +concept SpecifiesGemmBatchOptions = requires { + { T::num_conv_groups_to_merge } -> SizeType; +}; + +/******************************************** */ +/* Algorithm specialization concepts */ +/******************************************** */ template concept SpecifiesLargeTensorSupport = requires { { T::specialization } -> std::convertible_to; requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; }; +template +concept SpecifiesReferenceAlgorithm = requires { + { T::specialization } 
-> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; +}; + +template +concept SpecifiesTwoStageSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; +}; + +template +concept SpecifiesMultipleDSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ @@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires { // Concept for DL thread configuration template concept DlThreadConfigDescriptor = requires(T t) { - { t.k0_per_block } -> std::convertible_to; - { t.k1 } -> std::convertible_to; - { t.m1_per_thread } -> std::convertible_to; - { t.n1_per_thread } -> std::convertible_to; - { t.k_per_thread } -> std::convertible_to; + { t.k0_per_block } -> SizeType; + { t.k1 } -> SizeType; + { t.m1_per_thread } -> SizeType; + { t.n1_per_thread } -> SizeType; + { t.k_per_thread } -> SizeType; }; // Concept for DL thread cluster @@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) { }; // Concept for DL block transfer -template +template concept DlBlockTransferDescriptor = requires(T t) { - { t.thread_slice_lengths } -> std::convertible_to>; - { t.thread_cluster_lengths } -> std::convertible_to>; - { t.thread_cluster_arrange_order } -> std::convertible_to>; - { t.src_access_order } -> std::convertible_to>; - { t.src_vector_tensor_lengths } -> std::convertible_to>; - { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; - { t.dst_vector_tensor_lengths } -> std::convertible_to>; + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> 
std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; +template +concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor; + +template +concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor; + // Concept for DL epilogue template concept DlEpilogueDescriptor = requires(T t) { { t.src_dst_access_order } -> std::convertible_to>; - { t.src_dst_vector_dim } -> std::convertible_to; - { t.dst_scalar_per_vector } -> std::convertible_to; + { t.src_dst_vector_dim } -> SizeType; + { t.dst_scalar_per_vector } -> SizeType; }; // Concept to check if algorithm specifies DL thread config @@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires { // Concept to check if algorithm specifies DL block transfer template -concept SpecifiesDlBlockTransfer = requires { - { T::transfer.a.block_transfer } -> DlBlockTransferDescriptor; - { T::transfer.b.block_transfer } -> DlBlockTransferDescriptor; +concept SpecifiesDlFwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor4D; + { T::transfer.b } -> DlBlockTransferDescriptor4D; +}; + +template +concept SpecifiesDlBwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor5D; + { T::transfer.b } -> DlBlockTransferDescriptor5D; }; // Concept to check if algorithm specifies DL C thread transfer template concept SpecifiesDlEpilogue = requires { - { T::transfer.c.epilogue } -> DlEpilogueDescriptor; + { T::transfer.c } -> DlEpilogueDescriptor; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 10a619024a..5196eae6c7 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ 
-5,6 +5,9 @@ #include #include +#include +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/arch/arch.hpp" namespace ck_tile::builder { @@ -29,10 +32,240 @@ concept OutputVectorTransferLimits = requires { // Limits for access order. Must be a permutation of {0, 1, 2}. template -concept AccessOrderLimits = requires { +concept AccessOrderLimits3D = requires { requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) && (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) && - (Value[2] >= 0 && Value[2] < 3)); + (Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3)); }; +// Limits for access order. Must be a permutation of {0, 1, 2, 3}. +template +concept AccessOrderLimits4D = requires { + requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) && + (Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && + (Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) && + (Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) && + (Value.Size() == 4)); +}; + +namespace detail { + +// Helper to check if access order is a valid permutation +template +constexpr bool is_valid_permutation() +{ + constexpr auto size = Value.Size(); + + // Check all values are in range [0, size) + for(size_t i = 0; i < size; ++i) + { + if(Value[i] < 0 || Value[i] >= static_cast(size)) + return false; + } + + // Check all values are unique (valid permutation) + for(size_t i = 0; i < size; ++i) + { + for(size_t j = i + 1; j < size; ++j) + { + if(Value[i] == Value[j]) + return false; + } + } + + return true; +} + +} // namespace detail + +// Generic access order limits. Must be a valid permutation of {0, 1, ..., Dims-1}. +// Works with both 3D and 4D (or any dimensionality) access orders. 
+template +concept AccessOrderLimits = requires { + requires Value.Size() == Dims; + requires detail::is_valid_permutation(); +}; + +namespace detail { + +// Helper trait to get compile-time size from ck::Array +template +concept HasStaticSize = requires { + { T::Size() } -> std::convertible_to; +}; + +// Helper trait to get compile-time size from std::array and similar +template +concept HasTupleSize = requires { + { std::tuple_size::value } -> std::convertible_to; +}; + +// Helper for dependent static_assert +template +constexpr bool always_false = false; + +// Get compile-time size of a range +template +constexpr size_t get_range_size() +{ + if constexpr(HasStaticSize) + { + return Range::Size(); + } + else if constexpr(HasTupleSize) + { + return std::tuple_size_v; + } + else + { + static_assert(always_false, "Unsupported type of range object."); + } +} + +// Fold expression implementation for product calculation +template +constexpr auto get_cluster_size_impl(const Range& range, std::index_sequence) +{ + using value_type = std::remove_cvref_t; + return ((range[Is]) * ... * value_type{1}); +} + +// Generic function that calculates the product of all elements in a range +// Works with any indexable range with compile-time size (ck::Array, std::array, etc.) 
+template + requires requires(Range r) { + r[0]; // Must be indexable + get_range_size(); // Must have compile-time size + } +constexpr auto get_cluster_size(const Range& range) +{ + return get_cluster_size_impl(range, std::make_index_sequence()>{}); +} + +// Calculate K dimension coverage (k0 * k1, with vectorization if applicable) +template +constexpr auto get_k_coverage() +{ + auto k0 = BlockTransfer.thread_cluster_dims[0]; + auto k1 = BlockTransfer.thread_cluster_dims[2]; + auto k_total = k0 * k1; + + // If vectorization is on k0 (dim 0) or k1 (dim 2), multiply by vector size + if constexpr(BlockTransfer.src_vector_dim == 0 || BlockTransfer.src_vector_dim == 2) + { + k_total *= BlockTransfer.src_scalar_per_vector; + } + + return k_total; +} + +// Calculate M/N dimension coverage (m_n, with vectorization if applicable) +template +constexpr auto get_mn_coverage() +{ + auto mn = BlockTransfer.thread_cluster_dims[1]; + + // If vectorization is on m_n (dim 1), multiply by vector size + if constexpr(BlockTransfer.src_vector_dim == 1) + { + mn *= BlockTransfer.src_scalar_per_vector; + } + + return mn; +} + +template +constexpr auto get_data_max_vec_size() +{ + constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width(); + static_assert(max_vec_inst_size_bytes % DataTypeSize == 0, + "The max vec instruction size is not a multiple of given data type size."); + return max_vec_inst_size_bytes / DataTypeSize; +} + +} // namespace detail + +// product of thread cluster lengths must be <= workgroup size +template +concept ValidBlockTransferClusterSize = + requires { requires detail::get_cluster_size(BlockTransfer.thread_cluster_dims) <= BlockSize; }; + +// Check that thread cluster covers the K and M dimensions for A transfer +template +concept ThreadsCoverATile = requires { + // K dimension: k0 * k1 * (vectorization) must divide K + requires TileSize.k % detail::get_k_coverage() == 0; + // M dimension: m_n * (vectorization) must divide M + requires TileSize.m % 
detail::get_mn_coverage() == 0; +}; + +// Check that thread cluster covers the K and N dimensions for B transfer +template +concept ThreadsCoverBTile = requires { + // K dimension: k0 * k1 * (vectorization) must divide K + requires TileSize.k % detail::get_k_coverage() == 0; + // N dimension: m_n * (vectorization) must divide N + requires TileSize.n % detail::get_mn_coverage() == 0; +}; + +template +concept ThreadsCoverCTile = requires { + // M dimension: m_wave_per_xdl must divide M + requires TileSize.m % CBlockTransfer.thread_cluster_dims[1] == 0; + // N dimension: n_wave_per_xdl * (vectorization) must divide N + requires TileSize.n % (CBlockTransfer.thread_cluster_dims[3] * + CBlockTransfer.scalar_per_vector) == 0; +}; + +template +concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0); + +template +concept IsVectorSizeValid = + IsPowerOf2 && (ScalarPerVec <= detail::get_data_max_vec_size()); + +// Composite concept for input block transfer validation (A) +// Includes all validations: vector transfer limits, access order, cluster size, +// vector size validity, and tile coverage +template +concept ValidABlockTransfer = + InputVectorTransferLimits && + AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + IsVectorSizeValid && + ThreadsCoverATile; + +// Composite concept for input block transfer validation (B) +template +concept ValidBBlockTransfer = + InputVectorTransferLimits && + AccessOrderLimits && + AccessOrderLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + IsVectorSizeValid && + ThreadsCoverBTile; + +// Composite concept for output block transfer validation (C) +template +concept ValidCBlockTransfer = + OutputVectorTransferLimits && + ValidBlockTransferClusterSize && + IsVectorSizeValid && + ThreadsCoverCTile; + +// Usage: IsValidLayout +template +concept IsValidLayout = ck_tile::is_any_value_of(ACTUAL_LAYOUT, VALID_LAYOUTS...); + } // namespace ck_tile::builder diff --git 
a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 39e081ec8d..c9cb6fe767 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -80,6 +80,7 @@ concept ConvOutputLayout3D = (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) || (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided); +namespace detail { template concept HasDataType = requires(T t) { { t.data_type }; @@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - requires DataTypeWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; }; template @@ -116,7 +118,6 @@ template struct IsArrayOfTensorConfigDescriptors> : std::true_type { }; -} // namespace detail template concept ConvertibleToArrayOfTensorConfigs = @@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) { { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs; }; }; +} // namespace detail template concept TensorOperatorDescriptor = requires(T t) { { t.elementwise_operation } -> std::convertible_to; - requires AuxiliaryOperandConfigsWellDefinedIfProvided; + requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided; }; template @@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) { { t.operation }; }; +namespace detail { + template concept HasConvolutionDirection = requires(T t) { { t.direction }; @@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail + // Concept for the convolution tensor template concept ConvTensorDescriptor = requires(T t) { { t.config } -> TensorConfigDescriptor; - requires 
ElementwiseOpWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; template @@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) { { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; - requires ConvolutionDirectionWellDefinedIfProvided; - requires DataTypeWellDefinedIfProvided; + requires detail::ConvolutionDirectionWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; // Concept to validate a convolution signature's values. @@ -221,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim = (SpatialDim == 1 && ConvWeightLayout1D) || (SpatialDim == 2 && ConvWeightLayout2D) || (SpatialDim == 3 && ConvWeightLayout3D); +// Constraint for 3D conv signature. +template +concept Is3D = requires { + requires Sig.spatial_dim == 3; + requires ConvInputLayout3D; + requires ConvOutputLayout3D; + requires ConvWeightLayout3D; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp new file mode 100644 index 0000000000..79b818555e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -0,0 +1,128 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory { + +// Base algorithm concepts +template +concept TileTransferParameters = + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterArrangeOrder && SpecifiesSourceAccessOrder; + +template +concept SpecifiesTileTransferParameters3D = TileTransferParameters; + +template +concept SpecifiesTileTransferParameters4D = TileTransferParameters; + +template +concept FwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; + +template +concept BwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; + +template +concept BwdXdlV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; + +template +concept BwdWmmaAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; + +template +concept BwdWmmaV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; + +// Reference algorithm concept +template +concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; + +// Tile-based algorithm concept +template +concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && 
SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; + +// FWD XDL algorithm concepts +template +concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; + +template +concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; + +template +concept FwdXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; + +// FWD WMMA algorithm concepts +template +concept FwdWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline; + +// FWD DL algorithms +template +concept FwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; + +// BWD weight XDL algorithm concepts +template +concept BwdXdlAlgorithm = + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + +template +concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; + +template +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase && SpecifiesGenericInstance; + +template +concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; + +// BWD weight WMMA algorithm concepts +template +concept BwdWmmaAlgorithm = + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; + +template +concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && 
SpecifiesMultipleDSupport; + +template +concept BwdWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + +template +concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; + +// BWD weight DL algorithms +template +concept BwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp new file mode 100644 index 0000000000..fda1659c75 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Dl instance +// of a grouped bwd weight convolution kernel.
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightDlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using 
BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL backward weight convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, +
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp new file mode 100644 index 0000000000..b02dea9558 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. 
+template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightMultiDWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance.
+ using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp new file mode 100644 index 0000000000..4f6812617a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightMultiDXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed.
+ static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The backward weight convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp new file mode 100644 index
0000000000..adf108bac4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel.
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance.
+ using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp new file mode 100644 index 0000000000..d887c1c1ce --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters.
+ // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, +
typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp new file mode 100644 index 0000000000..4067845291 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance +// of a grouped bwd weight convolution kernel.
+template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightWmmaFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance.
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + ALGORITHM.num_gemm_k_prefetch_stages, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp new file mode 100644 index 0000000000..027c8a1fba --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed.
+ static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType, +
ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp new file mode 100644 index 0000000000..fbb177f333 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel.
+template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The backward weight convolution kernel class instance.
+ using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp new file mode 100644 index 0000000000..66a47c5407 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -0,0 +1,108 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed.
+ static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The backward weight convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff
--git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index c0dd3d8018..e235db4bb0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -57,6 +57,9 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" +// Compile time diagnostics +#include "ck_tile/builder/factory/conv_algorithms.hpp" + // Include all factory implementations #include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" @@ -65,6 +68,15 @@ #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/reference_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -87,56 +99,6 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users.
-// Reference algorithm (simplest implementation for validation) -template -concept IsReferenceAlgorithm = ConvAlgorithmDescriptor && requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; -}; - -// CK Tile kernel -template -concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; - -// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) -template -concept IsXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; - -// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) -template -concept IsXdlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; - -// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) -template -concept IsWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts -template -concept IsDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && 
SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -// XDL-based kernel with large tensor support -template -concept IsLargeTensorAlgorithm = - IsXdlAlgorithm && SpecifiesLargeTensorSupport; - template @@ -145,35 +107,35 @@ constexpr auto make_conv_instance() using AlgoType = std::remove_const_t; // Reference algorithm supports all directions - if constexpr(IsReferenceAlgorithm) + if constexpr(ReferenceAlgorithm) { return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - else if constexpr(IsTileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm) + if constexpr(FwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm) + else if constexpr(FwdXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm) + else if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm) + else if constexpr(LargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } @@ -197,10 +159,55 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - static_assert(false, - "Backward weight convolution: Only reference and tile algorithms " - "supported currently. " - "Optimized kernels (XDL, WMMA, etc.) 
not yet implemented."); + if constexpr(BwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else if constexpr(BwdXdlV3Algorithm) + { + return typename ConvBwdWeightXdlV3Factory::Instance{}; + } + else if constexpr(BwdTwoStageXdlAlgorithm) + { + return + typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + } + else if constexpr(BwdDlAlgorithm) + { + return typename ConvBwdWeightDlFactory::Instance{}; + } + else if constexpr(BwdMultiDXdlAlgorithm) + { + return + typename ConvBwdWeightMultiDXdlFactory::Instance{}; + } + else if constexpr(BwdWmmaV3Algorithm) + { + return typename ConvBwdWeightWmmaV3Factory::Instance{}; + } + else if constexpr(BwdTwoStageWmmaV3Algorithm) + { + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; + } + else if constexpr(BwdWmmaAlgorithm) + { + return typename ConvBwdWeightWmmaFactory::Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; + } + else + { + static_assert( + false, + "No suitable backward weight convolution kernel factory found for the provided " + "ALGORITHM. 
The ALGORITHM must satisfy requirements for one of: Reference, Tile, " + "XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage " + "WMMA V3, WMMA, or Multi-D WMMA V3 variant."); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ca202aabfd..1d55772dd6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,10 +24,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -48,7 +48,7 @@ struct ConvFwdDlFactory using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -64,7 +64,7 @@ struct ConvFwdDlFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -80,7 +80,7 @@ struct ConvFwdDlFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr 
auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -89,18 +89,18 @@ struct ConvFwdDlFactory // The DL forward convolution kernel class instance using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< SPATIAL_DIM, - typename Types::ADataType, - typename Types::BDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::DsDataType, + typename Types::OutDataType, typename Types::AccDataType, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Layouts::OutLayout, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, FWD_CONV_SPECIALIZATION, GEMM_SPECIALIZATION, BLOCK.block_size, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fadf41f48a..b80406c37e 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,68 +26,106 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using 
AlgorithmType = decltype(ALGORITHM); - static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; - - static constexpr auto FWD_CONV_SPECIALIZATION = - internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - internal::SetCBlockTransfer(); + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); - // Check limits for the algorithm parameters. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + // Check limits for the data transfer parameters. 
+ static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance with large tensor support. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - BASE_ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -106,8 +144,8 @@ struct ConvFwdLargeTensorFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + 
typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 89787cc1b3..74554df7e9 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == ALGORITHM.transfer.b.lds_transfer.is_direct_load, @@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -51,31 +52,81 @@ struct ConvFwdXdlV3Factory static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. 
- static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + // Layout validations + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, BLOCK.block_size, @@ -84,10 +135,10 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -108,8 +159,8 @@ 
struct ConvFwdXdlV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, IS_DIRECT_LOAD>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index bb84479071..cb36122f7c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -48,31 +48,73 @@ struct ConvFwdWmmaFactory static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. 
- static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + // TODO: verify Ds transfer as well + + // Layout validations (same as DeviceGroupedConvFwdMultipleD_Wmma_CShuffle) + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 8ec5c633ce..b3be21f1f3 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -39,6 +39,7 @@ struct ConvFwdXdlFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -46,31 +47,81 @@ struct ConvFwdXdlFactory static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. 
- static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(ValidABlockTransfer); + static_assert(ValidBBlockTransfer); + static_assert(ValidCBlockTransfer); + + // Layout validations (same as DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + using enum TensorLayout; + static_assert(IsValidLayout && + A_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout && + B_BLOCK_TRANSFER.src_vector_dim == 2); + + static_assert(IsValidLayout); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, @@ -80,10 +131,10 @@ struct ConvFwdXdlFactory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + 
XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -102,10 +153,10 @@ struct ConvFwdXdlFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER, - ALGORITHM.num_groups_to_merge>; + ALGORITHM.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index cce95cb3f1..6ce508b47d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -116,7 +116,6 @@ struct ConvTileFactory BLOCK_GEMM.warp_tile.k, GroupedConvTraitsType::FixedGemmParams::TransposeC, // TODO:: This template parameter will be moved inside the kernel - ck_tile::memory_operation_enum::set, BLOCK_GEMM.num_wave_groups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, SCALAR_PER_VECTOR.c>>; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 5da1e4eadb..249fe0ba24 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -10,27 +10,28 @@ namespace ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. 
+template struct BlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 - ck::Array thread_cluster_order = {0, 0, 0}; - ck::Array src_access_order = {0, 0, 0}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; }; template -constexpr BlockTransfer SetFwdConvBlockTransfer() +constexpr BlockTransfer<> SetFwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_order = TRANSFER.thread_cluster_arrange_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BlockTransfer{ + return BlockTransfer<>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, @@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() }; } +template +constexpr auto SetBwdConvBlockTransfer() +{ + auto& block_xfer = TRANSFER.block_transfer; + auto& block_order = TRANSFER.thread_cluster_arrange_order; + auto& src_order = TRANSFER.src_access_order; + auto& lds_cfg = TRANSFER.lds_transfer; + + constexpr auto array_length = block_order.order.size(); + static_assert(block_order.order.size() == src_order.order.size(), + "Mismatched size between block order and src order"); + + if constexpr(array_length == 3) + { + return BlockTransfer<3>{ + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + 
block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, + }; + } + else if constexpr(array_length == 4) + { + return BlockTransfer<4>{ + .thread_cluster_dims = {block_xfer.k_batch_size, + block_xfer.k0, + block_xfer.m_n, + block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2], + block_order.order[3]}, + .src_access_order = {src_order.order[0], + src_order.order[1], + src_order.order[2], + src_order.order[3]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, + }; + } + else + { + static_assert(false, "Internal error: Unsupported array length"); + } +} + // Block transfer parameters for C tensor. 
struct CBlockTransfer { diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index a39cd7410b..0cc43fc679 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,14 +62,15 @@ consteval auto GetElementwiseOp() } template -struct ElementwiseOps +struct ConvElementwiseOps { static constexpr auto input_op = GetElementwiseOp(); static constexpr auto weight_op = GetElementwiseOp(); static constexpr auto output_op = GetElementwiseOp(); - using AElementwiseOp = typename decltype(input_op)::Op; - using BElementwiseOp = typename decltype(weight_op)::Op; - using CDEElementwiseOp = typename decltype(output_op)::Op; + + using InElementwiseOp = typename decltype(input_op)::Op; + using WeiElementwiseOp = typename decltype(weight_op)::Op; + using OutElementwiseOp = typename decltype(output_op)::Op; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index a6c0b48c54..fd6de9ae21 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence) decltype(TensorLayoutToCK())...>{}; } -template +template requires(ConvSpatialDim) struct AuxiliaryTensorLayouts { @@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts }; // TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). 
-template +template requires(HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return AuxiliaryTensorLayouts{}; + SPATIAL_DIM>{}; } -template +template requires(!HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return EmptyAuxiliaryTensorLayout{}; } -template +template requires(ConvSpatialDim && ValidConvInputLayoutForSpatialDim && ValidConvWeightLayoutForSpatialDim && ValidConvOutputLayoutForSpatialDim) struct ConvTensorLayouts { - static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported."); - using ALayout = decltype(TensorLayoutToCK()); - using BLayout = decltype(TensorLayoutToCK()); - using ELayout = decltype(TensorLayoutToCK()); - using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); + using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index c819e11d00..0c017e0c47 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -33,7 +33,7 @@ struct DataTypeToCK using type = float; }; template <> -struct DataTypeToCK +struct DataTypeToCK { using type = int32_t; }; @@ -47,6 +47,11 @@ struct DataTypeToCK { using type = ck::f8_t; }; +template <> +struct DataTypeToCK +{ + using type = uint8_t; +}; struct CK_empty_tuple { @@ -151,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes() } template -struct FwdConvTensorDataTypes +struct ConvTensorDataTypes { static constexpr auto input_types = GetTensorDataAndComputeTypes(); @@ -160,20 +165,17 @@ 
struct FwdConvTensorDataTypes static constexpr auto output_types = GetTensorDataAndComputeTypes(); - using ADataType = typename decltype(input_types.first)::type; - using AComputeType = typename decltype(input_types.second)::type; - using BDataType = typename decltype(weight_types.first)::type; - using BComputeType = typename decltype(weight_types.second)::type; + using InDataType = typename decltype(input_types.first)::type; + using InComputeType = typename decltype(input_types.second)::type; + using WeiDataType = typename decltype(weight_types.first)::type; + using WeiComputeType = typename decltype(weight_types.second)::type; + using OutDataType = typename decltype(output_types.first)::type; + using OutComputeType = typename decltype(output_types.second)::type; using AccDataType = typename decltype(GetTensorAccumulationType())::type; - using EDataType = typename decltype(output_types.first)::type; - - // This is the "compute" type for output. - using CShuffleDataType = typename decltype(output_types.second)::type; - // Data types for the auxiliary tensors (e.g., bias). 
- using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; + using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index db741f2112..9ed1eebc3c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -37,7 +38,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm; + constexpr auto& BG = ALGORITHM.block_gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; @@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler() template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { - constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; + constexpr auto pipeline_version = ALGORITHM.pipeline_version; using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { @@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(specialization) { - case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: return 
ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; - case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC; - default: throw "Unknown ConvFwdSpecialization"; + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; + } +} + +template +consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization +SetBwdWeightConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_weight_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + case ConvSpecialization::FILTER_3x3: + throw "FILTER_3x3 is not supported for backward weight convolution."; + default: throw "Unsupported ConvSpecialization"; } } diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index 0246c805c2..f6fc2dbda8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -26,11 +26,11 @@ struct ReferenceFactory static constexpr auto kValidation = 
(internal::ValidateReferenceSignature(), 0); static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; - using InDataType = typename Types::ADataType; - using WeiDataType = typename Types::BDataType; - using OutDataType = typename Types::EDataType; + using InDataType = typename Types::InDataType; + using WeiDataType = typename Types::WeiDataType; + using OutDataType = typename Types::OutDataType; struct Instance { @@ -125,9 +125,9 @@ struct ReferenceFactory // Direct Run method (simpler interface, direction-agnostic) template - static void Run(InPtrType input, - WeiPtrType weight, - OutPtrType output, + static void Run(InPtrType* input, + WeiPtrType* weight, + OutPtrType* output, int G, int N, int K, @@ -142,9 +142,9 @@ struct ReferenceFactory if constexpr(ConvDirectionIsForward) { ck_tile::naive_grouped_conv_fwd( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -160,9 +160,9 @@ struct ReferenceFactory { ck_tile:: naive_grouped_conv_bwd_data( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -179,19 +179,20 @@ struct ReferenceFactory ck_tile::naive_grouped_conv_bwd_weight(input, - weight, - output, - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + OutDataType>( + static_cast(input), + static_cast(weight), + static_cast(output), + G, + N, + K, + C, + input_spatial, + filter_spatial, + output_spatial, + strides, + dilations, + left_pads); } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp index fdbfa7c4e1..359b12c4a3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp @@ 
-7,43 +7,52 @@ #pragma once #include "ck_tile/builder/reflect/conv_description.hpp" -#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp" namespace ck_tile::reflect { -/// @brief Factory function to create ConvDescription from a convolution instance type -/// @tparam Instance The convolution instance type (must have ConvTraits) -/// @return A ConvDescription object populated with the instance's configuration details -template +/// @brief Concept to check if an Instance type has conv traits +template +concept HasConvTraits = requires { + { conv::instance_to_conv_traits() }; +}; + +/// Factory function to create ConvDescription from a convolution instance type +/// Instance The convolution instance type +/// A ConvDescription object populated with the instance's configuration details +/// +/// TODO: Fix ConvDescription to just use the ConvTraits directly. +template + requires HasConvTraits conv::ConvDescription describe() { - using Traits = conv::ConvTraits; + const auto traits = conv::instance_to_conv_traits(); return conv::ConvDescription( conv::ConvSignatureInfo{ - .spatial_dim = Traits::spatial_dim, - .direction = Traits::direction, - .input_layout = Traits::layout[0], - .weight_layout = Traits::layout[1], - .output_layout = Traits::layout[2], - .data_type = Traits::data_type, - .input_element_op = Traits::input_element_op, - .weight_element_op = Traits::weight_element_op, - .output_element_op = Traits::output_element_op, + .spatial_dim = traits.spatial_dim, + .direction = traits.direction, + .input_layout = traits.layout[0], + .weight_layout = traits.layout[1], + .output_layout = traits.layout[2], + .data_type = traits.data_type, + .input_element_op = traits.input_element_op, + .weight_element_op = traits.weight_element_op, + .output_element_op = traits.output_element_op, }, conv::GemmAlgorithmInfo{ - .thread_block_size = Traits::thread_block_size, - .tile_dims = Traits::tile_dims, - .warp_gemm = 
Traits::warp_gemm, - .a_tile_transfer = Traits::a_tile_transfer, - .b_tile_transfer = Traits::b_tile_transfer, - .c_tile_transfer = Traits::c_tile_transfer, - .pipeline_version = Traits::pipeline_version, - .pipeline_scheduler = Traits::pipeline_scheduler, - .conv_specialization = Traits::conv_specialization, - .padding = Traits::gemm_padding, + .thread_block_size = traits.thread_block_size, + .tile_dims = traits.tile_dims, + .warp_gemm = traits.warp_gemm, + .a_tile_transfer = traits.a_tile_transfer, + .b_tile_transfer = traits.b_tile_transfer, + .c_tile_transfer = traits.c_tile_transfer, + .pipeline_version = traits.pipeline_version, + .pipeline_scheduler = traits.pipeline_scheduler, + .conv_specialization = traits.conv_specialization, + .padding = traits.gemm_padding, }, - []() { return reflect::instance_string(); }); + []() { return reflect::instance_string(); }); } } // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 46c9bb488e..a7b6c60a73 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -63,10 +63,7 @@ struct GemmAlgorithmInfo OutputTileTransferInfo c_tile_transfer; builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; - std::variant - conv_specialization; + builder::ConvSpecialization conv_specialization; builder::GemmPadding padding; }; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a91abd1a46..451a74be34 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,670 +1,109 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT +// Runtime-accessible convolution kernel configuration data structure +// +// This file defines ConvTraits, a pure data structure that captures the complete +// configuration of a convolution kernel in a domain-specific abstraction, without +// requiring knowledge of the underlying kernel instance implementation details. +// +// ## Purpose and Design +// +// ConvTraits provides type erasure for convolution kernel configurations, allowing +// for reflection of convolution kernel objects. The struct represents kernel +// traits in terms of convolution-specific concepts for AMD GPUs rather than raw +// template parameters. +// +// ## Architecture and Usage +// +// ConvTraits sits at the center of the reflection system: +// +// 1. **Population**: Values are created by `instance_to_conv_traits()` template +// specializations that extract configuration from compile-time InstanceTraits +// +// 2. **Consumption**: Used by ConvDescription to provide human-readable descriptions +// of kernel configurations for debugging, logging, and documentation +// +// ## Structure Organization +// +// The struct separates kernel configuration into two logical categories: +// +// - **Signature Information**: Defines what the kernel computes (direction, layouts, +// data types, elementwise operations, specializations) +// +// - **Algorithm Information**: Defines how the kernel computes (thread block size, +// tile dimensions, memory access patterns, pipeline configuration) +// +// ## Evolution and Extensibility +// +// ConvTraits is designed to evolve through composition (not inheritance): +// +// - Currently supports XDL forward convolution kernels +// - Will extend to the other forward convolutions +// - Will be extended to cover backward data and backward weight convolutions +// - Will incorporate fusion operations and additional specializations +// - Uses std::optional and std::variant for optional/variant fields +// - Eventually will generalize to 
KernelTraits for GEMM, flash attention, etc. + #pragma once -#include -#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/pipeline_enum.hpp" -#include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/reflect/conv_types.hpp" -#include "ck_tile/builder/reflect/instance_traits.hpp" -#include "ck_tile/builder/reflect/instance_traits_util.hpp" #include "ck_tile/builder/types.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { -// Forward convolution layout concept - checks for A/B/E layout types -template -concept HasFwdConvLayouts = requires { - typename T::ALayout; - typename T::BLayout; - typename T::ELayout; -}; - -// GEMM specialization concept - checks for kGemmSpecialization member -template -concept HasGemmSpec = requires { - { - T::kGemmSpecialization - } -> std::convertible_to; -}; - -// Data types concept - checks for ADataType member -template -concept HasDataTypes = requires { typename T::ADataType; }; - -// Elementwise operations concept - checks for A/B/CDE elementwise operation types -template -concept HasElementwiseOps = requires { - typename T::AElementwiseOperation; - typename T::BElementwiseOperation; - typename T::CDEElementwiseOperation; -}; - -// Tile parameters concept - checks for tile dimension and transfer members -template -concept HasTileParams = requires { - { T::kKPerBlock } -> std::convertible_to; - { T::kMPerBlock } -> std::convertible_to; - { T::kNPerBlock } -> std::convertible_to; - { T::kAK1 } -> std::convertible_to; - { T::kBK1 } -> 
std::convertible_to; - T::kCThreadClusterLengths; -}; - -// Comprehensive concept that checks if an instance has all XDL forward convolution traits -// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions -template -concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && - HasElementwiseOps && HasTileParams; - -// Primary concept for checking if a type can be described -// Currently only forward convolutions are supported, but this can be extended -// in the future to include backward data and backward weight convolutions -template -concept HasConvTraits = IsXdlFwdConv>; - -// Helper metafunctions to convert from ck enums to builder enums - -/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). -/// @details This function maps CK's block GEMM pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. The pipeline version -/// determines the strategy used for data movement and computation overlap in the -/// GEMM kernel's main loop. -template -constexpr auto convert_pipeline_version() +// Runtime data structure representing a convolution kernel's complete configuration +// +// This pure data struct (no template parameters, no static members) provides +// type erasure for convolution kernel configurations. It can hold the configuration +// from any convolution kernel instance, enabling runtime storage, comparison, and +// manipulation of kernel properties. +// +// The struct is populated by `instance_to_conv_traits()` template specializations +// that extract compile-time configuration from InstanceTraits and convert it to +// this standardized runtime representation. 
+// +// Members are organized into two categories: +// - **Signature Information**: Defines the computational interface (what to compute) +// - **Algorithm Information**: Defines the implementation strategy (how to compute) +// +// Note: This struct will evolve to support additional convolution variants and +// eventually generalize to other kernel types through composition. +// +// There is a lot we still need to do: +// +// TODO: Generalize type support for all tensors and accumulator. +// TODO: Describe all tensors. +// TODO: Include the full generalization of the signature from the input schema. +// TODO: Include the full generalization of the algorithm from the input schema. +struct ConvTraits { - using enum ck::BlockGemmPipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v3: return V3; - case v4: return V4; - case v5: return V5; - } -} - -/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK PipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). -/// @details This function maps CK's general pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. Note that this overload -/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion -/// variant, including support for specialized weight-only pipelines. -template -constexpr auto convert_pipeline_version() -{ - using enum ck::PipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v4: return V4; - case weight_only: return WEIGHT_ONLY; - } -} - -/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
-/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). -/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the -/// builder framework's standardized scheduler enum. The scheduler determines how work -/// is distributed and synchronized within and across wavefronts during pipeline execution. -/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates -/// across multiple wavefronts. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::BlockGemmPipelineScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Intrawave: return INTRAWAVE; - case Interwave: return INTERWAVE; - } -} - -/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK LoopScheduler enum value to convert. -/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). -/// @details This function maps CK's loop scheduler identifiers to the builder framework's -/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of -/// the main computational loop are scheduled across threads. DEFAULT uses the standard -/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved -/// performance in certain scenarios. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::LoopScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Default: return DEFAULT; - case Interwave: return INTERWAVE; - } -} - -// Helper metafunctions to derive signature information from Instance types - -/// @brief Helper function to report unsupported convolution direction with a clear error message. 
-template -[[noreturn]] consteval void report_unsupported_conv_direction_error() -{ - throw "Unsupported convolution direction detected!\n" - "The kernel instance does not have a recognized convolution specialization.\n" - "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " - "kConvBwdWeightSpecialization.\n" - "Please verify that your kernel instance is properly configured."; -} - -/// @brief Derives the convolution direction from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). -template -constexpr builder::ConvDirection conv_direction() -{ - using InstTraits = InstanceTraits; - - if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - return builder::ConvDirection::FORWARD; - else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - return builder::ConvDirection::BACKWARD_DATA; - else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - return builder::ConvDirection::BACKWARD_WEIGHT; - else - { - report_unsupported_conv_direction_error(); - return builder::ConvDirection::FORWARD; // Unreachable - } -} - -/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or -/// `builder::ConvBwdWeightSpecialization` enum value. 
-template -constexpr auto conv_spec() -{ - using InstTraits = InstanceTraits; - - if constexpr(requires { InstTraits::kConvForwardSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; - using enum builder::ConvFwdSpecialization; - - switch(InstTraits::kConvForwardSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter3x3: return FILTER_3x3; - case OddC: return ODD_C; - } - } - else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; - using enum builder::ConvBwdDataSpecialization; - - switch(InstTraits::kConvBwdDataSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - } - } - else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; - using enum builder::ConvBwdWeightSpecialization; - - switch(InstTraits::kConvBwdWeightSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case OddC: return ODD_C; - } - } -} - -// Helper variable template to check if CK layout enums match -template -inline constexpr bool layouts_are = - std::is_same_v && std::is_same_v && std::is_same_v; - -/// @brief Helper function to report unsupported layout combinations with a clear error message. -/// @details This consteval function is designed to fail at compile time with a descriptive -/// error message when an unsupported layout combination is encountered. 
-template -[[noreturn]] consteval void report_unsupported_layout_error() -{ - // This will produce a compile-time error with the exception message - throw "Unsupported convolution layout combination detected!\n" - "The combination of ALayout, BLayout, and ELayout template parameters\n" - "is not recognized for the given spatial dimension.\n" - "Please verify that your convolution instance uses a supported layout configuration.\n" - "Check the conv_layout() function for the list of supported layout combinations."; -} - -/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array corresponding to the tensor layouts: -/// index 0 -> Input layout -/// index 1 -> Weight layout -/// index 2 -> Output layout -template -constexpr auto conv_layout() - requires HasFwdConvLayouts> -{ - // Helper lambda to construct layout array - auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - - using A = typename InstanceTraits::ALayout; - using B = typename InstanceTraits::BLayout; - using E = typename InstanceTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - switch(InstanceTraits::kSpatialDim) - { - case 1: - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(NWGC, GKXC, NWGK); - if constexpr(layouts_are) - return layouts(NGCW, GKXC, NGKW); - if constexpr(layouts_are) - return layouts(NGCW, GKCX, NGKW); - break; - case 2: - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NGCHW, GKYXC, NGKHW); - if constexpr(layouts_are) - return 
layouts(NGCHW, GKCYX, NGKHW); - break; - case 3: - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(NDHWGC, GKZYXC, NDHWGK); - if constexpr(layouts_are) - return layouts(NGCDHW, GKZYXC, NGKDHW); - if constexpr(layouts_are) - return layouts(NGCDHW, GKCZYX, NGKDHW); - break; - } - - // If we reach here, the layout combination is not supported - // Call consteval function to trigger a compile-time error with a clear message - report_unsupported_layout_error::kSpatialDim>(); - - // This return is unreachable but needed to satisfy the compiler - return layouts(GNHWC, GKYXC, GNHWK); -} - -/// @brief Helper function to report unsupported data type with a clear error message. -template -[[noreturn]] consteval void report_unsupported_data_type_error() -{ - throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" - "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " - "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " - "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " - "(BF8), " - "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" - "Please verify that your kernel instance uses a supported data type."; -} - -/// @brief Derives the data type from a device kernel `Instance` type. -/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). 
-template -constexpr builder::DataType conv_data_type() - requires HasDataTypes> -{ - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; - using enum builder::DataType; - - if constexpr(std::is_same_v) - return FP16; - else if constexpr(std::is_same_v>) - return FP16_FP16; - else if constexpr(std::is_same_v) - return BF16; - else if constexpr(std::is_same_v>) - return BF16_BF16; - else if constexpr(std::is_same_v) - return FP32; - else if constexpr(std::is_same_v>) - return FP32_FP32; - else if constexpr(std::is_same_v) - return FP64; - else if constexpr(std::is_same_v) - return FP8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return I8; - else if constexpr(std::is_same_v>) - return I8_I8; - else if constexpr(std::is_same_v) - return U8; - else - { - report_unsupported_data_type_error(); - return FP32; // Unreachable - } -} - -/// @brief Helper function to report unsupported elementwise operation with a clear error message. -template -[[noreturn]] consteval void report_unsupported_elementwise_op_error() -{ - throw "Unsupported elementwise operation detected!\n" - "The elementwise operation type is not recognized.\n" - "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " - "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " - "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " - "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " - "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " - "UnaryConvert.\n" - "Please verify that your kernel instance uses a supported elementwise operation."; -} - -/// @brief Derives the elementwise operation from op type. -/// @tparam ElementwiseOp Elementwise operation functor type. 
-/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. -template -constexpr builder::ElementwiseOperation elementwise_op() -{ - using enum builder::ElementwiseOperation; - constexpr std::string_view name = detail::elementwise_op_name(); - - if constexpr(detail::case_insensitive_equal(name, "AddClamp")) - return ADD_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) - return ADD_RELU_ADD; - else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) - return BILINEAR; - else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - return CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) - return CONV_INVSCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) - return CONV_SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) - return CONV_SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) - return CONV_SCALE_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - return SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) - return SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - return PASS_THROUGH; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - return SCALEADD_SCALEADD_RELU; - else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) - return DYNAMIC_UNARY_OP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) - return UNARY_COMBINED_OP; - else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) - return ACTIVATION_MUL2_CLAMP; - else if 
constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) - return ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) - return ADD_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) - return ADD_ACTIVATION_MUL2_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) - return ADD_MUL_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) - return ADD_MUL2_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) - return UNARY_CONVERT; - else if constexpr(detail::case_insensitive_equal(name, "Logistic")) - return LOGISTIC; - else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) - return CLIPPED_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Swish")) - return SWISH; - else if constexpr(detail::case_insensitive_equal(name, "Elu")) - return ELU; - else if constexpr(detail::case_insensitive_equal(name, "Power")) - return POWER; - else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) - return LEAKY_RELU; - else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) - return UNARY_ABS; - else if constexpr(detail::case_insensitive_equal(name, "Relu")) - return RELU; - else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) - return SOFT_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) - return SIGMOID; - else if constexpr(detail::case_insensitive_equal(name, "TanH")) - return TANH; - else if constexpr(detail::case_insensitive_equal(name, "Gelu")) - return GELU; - else if constexpr(detail::case_insensitive_equal(name, "Silu")) - return SILU; - else - { - report_unsupported_elementwise_op_error(); - return PASS_THROUGH; // Unreachable - } -} - -/// @brief Derives a gemm padding from a kernel instance type. 
-/// @tparam Instance - A Device Kernel object type. -/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. -template -constexpr builder::GemmPadding gemm_spec() - requires HasGemmSpec> -{ - using InstTraits = InstanceTraits; - using enum builder::GemmPadding; - using enum ck::tensor_operation::device::GemmSpecialization; - - constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - - switch(gemm_spec) - { - case Default: return DEFAULT; - case MPadding: return M_PADDING; - case NPadding: return N_PADDING; - case KPadding: return K_PADDING; - case MNPadding: return MN_PADDING; - case MKPadding: return MK_PADDING; - case NKPadding: return NK_PADDING; - case MNKPadding: return MNK_PADDING; - case OPadding: return O_PADDING; - case MOPadding: return MO_PADDING; - case NOPadding: return NO_PADDING; - case KOPadding: return KO_PADDING; - case MNOPadding: return MNO_PADDING; - case MKOPadding: return MKO_PADDING; - case NKOPadding: return NKO_PADDING; - case MNKOPadding: return MNKO_PADDING; - } -} - -/// @brief Primary template for extracting convolution traits. -/// @details This struct is the main entry point for reflecting on a convolution -/// kernel's properties. It is specialized to handle different kinds of input types. -template -struct ConvTraits; - -/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. -/// @details This is the primary specialization used to extract a comprehensive -/// set of traits directly from a fully-formed device kernel `Instance` type. -/// It uses `InstanceTraits` to access the kernel's template parameters. -template - requires IsXdlFwdConv> -struct ConvTraits -{ - using InstTraits = InstanceTraits; - // --- Signature Information --- - /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). - static constexpr int spatial_dim = InstTraits::kSpatialDim; - /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). 
- static constexpr builder::ConvDirection direction = conv_direction(); - /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). - static constexpr auto layout = conv_layout(); - /// @brief The primary data type used in the computation (e.g., FP16, FP32). - static constexpr builder::DataType data_type = conv_data_type(); + int spatial_dim; + builder::ConvDirection direction; + std::array layout; // [input, weight, output] + builder::DataType data_type; - static constexpr builder::ElementwiseOperation input_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation weight_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation output_element_op = - elementwise_op(); + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; - /// @brief The GEMM specialization used by the kernel - padding - static constexpr auto gemm_padding = gemm_spec(); - /// @brief The convolution-specific specialization (e.g., Default, 1x1). - static constexpr auto conv_specialization = conv_spec(); + builder::GemmPadding gemm_padding; + builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- - /// @brief The total number of threads in a thread block (workgroup). - static constexpr int thread_block_size = InstTraits::kBlockSize; - /// @brief The dimensions of the data tile processed by the thread block. - static constexpr DataTileInfo tile_dims = { - .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + int thread_block_size; + DataTileInfo tile_dims; - /// @brief Configuration for the A-matrix (input) tile transfer. 
- static constexpr InputTileTransferInfo a_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; - /// @brief Configuration for the B-matrix (weights) tile transfer. - static constexpr InputTileTransferInfo b_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + WarpGemmParams warp_gemm; - /// @brief Parameters for the warp-level GEMM computation. - static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}; + OutputTileTransferInfo c_tile_transfer; - /// @brief Configuration for the C-matrix (output) tile transfer. 
- static constexpr OutputTileTransferInfo c_tile_transfer = { - .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; - - /// @brief Helper to safely get the pipeline version. - /// @details This is only available for some convolutions (e.g., forward). - /// If not present in `InstanceTraits`, it returns a default value. - template - static constexpr auto get_pipeline_version() - { - if constexpr(requires { T::kPipelineVersion; }) - { - return convert_pipeline_version(); - } - else - { - // Return a default or indicate not available - return builder::PipelineVersion::V1; - } - } - - /// @brief The block GEMM pipeline version used by the kernel. - static constexpr auto pipeline_version = get_pipeline_version(); - - /// @brief Helper to safely get the pipeline scheduler. - /// @details This is only available for some convolutions. If not present - /// in `InstanceTraits`, it returns a default value. - template - static constexpr auto get_pipeline_scheduler() - { - if constexpr(requires { T::kPipelineScheduler; }) - { - return convert_pipeline_scheduler(); - } - else if constexpr(requires { T::kLoopScheduler; }) - { - return convert_pipeline_scheduler(); - } - else - { - // Return a default or indicate not available - return builder::PipelineScheduler::DEFAULT; - } - } - - /// @brief The pipeline scheduler used by the kernel. 
- static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000..cdd238f36a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / 
InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace 
ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..28c43c342f --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = 
InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000..c4bed850eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = 
InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp new file mode 100644 index 0000000000..46c196e95a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -0,0 +1,739 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/conv_types.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +/// @file conv_traits_helpers.hpp +/// @brief Helper utilities for extracting convolution traits from kernel instances +/// +/// This file provides compile-time reflection utilities to extract configuration +/// information from CK convolution kernel instances and convert them to the builder +/// framework's standardized representation. +/// +/// ## Organization +/// +/// The file is organized into the following sections: +/// +/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums +/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion) +/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler) +/// +/// 2. 
**Signature Derivation**: Functions to extract signature information from instances +/// - Convolution direction (conv_direction) +/// - Convolution specialization (conv_spec) +/// - Tensor layouts (conv_layout) +/// - Data types (conv_data_type) +/// - Elementwise operations (elementwise_op) +/// - GEMM padding (gemm_spec) +/// +/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters +/// - Pipeline version extraction (get_pipeline_version) +/// - Pipeline scheduler extraction (get_pipeline_scheduler) +/// +/// ## Error Handling Strategy +/// +/// This file uses a specific error handling pattern for compile-time errors: +/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't +/// silently ignore errors. The thrown string becomes part of the compiler error message, +/// providing clear context to developers. +/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE, +/// which would hide errors instead of reporting them clearly. +/// +/// @example +/// ```cpp +/// using Instance = +/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>; +/// +/// // Extract convolution direction +/// constexpr auto dir = conv_direction(); +/// +/// // Extract data type +/// constexpr auto dtype = conv_data_type(); +/// +/// // Extract layout configuration +/// constexpr auto layouts = conv_layout(); +/// ``` + +namespace ck_tile::reflect::conv { + +// ============================================================================ +// SECTION 1: ENUM CONVERSIONS +// ============================================================================ + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. 
+/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v3 -> V3 +/// - v4 -> V4 +/// - v5 -> V5 +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v4 -> V4 +/// - weight_only -> WEIGHT_ONLY +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. 
+/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// +/// Supported mappings: +/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront +/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. 
+/// +/// Supported mappings: +/// - Default -> DEFAULT: Standard scheduling strategy +/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } +} + +// ============================================================================ +// SECTION 2: SIGNATURE DERIVATION FUNCTIONS +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Convolution Direction +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution direction with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_direction_error() +{ + throw "Unsupported convolution direction detected!\n" + "The kernel instance does not have a recognized convolution specialization.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution direction from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). +/// @details This function inspects the Instance's InstanceTraits to determine which +/// convolution specialization field is present, and returns the corresponding direction. 
+/// +/// The function checks for the presence of: +/// - kConvForwardSpecialization -> FORWARD +/// - kConvBwdDataSpecialization -> BACKWARD_DATA +/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT +/// +/// @note Compilation will fail with a clear error message if the instance does not +/// have a recognized convolution specialization field. +template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + return builder::ConvDirection::FORWARD; + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + return builder::ConvDirection::BACKWARD_DATA; + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + return builder::ConvDirection::BACKWARD_WEIGHT; + else + { + report_unsupported_conv_direction_error(); + return builder::ConvDirection::FORWARD; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Convolution Specialization +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution specialization with a clear error +/// message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_spec_error() +{ + throw "Unsupported convolution specialization detected!\n" + "The kernel instance does not have a recognized convolution specialization field.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution-specific specialization from a device kernel Instance type. 
+/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvSpecialization enum value. +/// @details This function extracts the specialization enum from the Instance's InstanceTraits +/// and converts it to the corresponding builder framework enum. +/// +/// For forward convolutions, supported specializations include: +/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC +/// +/// For backward data convolutions: +/// - Default, Filter1x1Stride1Pad0 +/// +/// For backward weight convolutions: +/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC +template +constexpr builder::ConvSpecialization conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvForwardSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; + case OddC: return ODD_C; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdDataSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdWeightSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; + } + } + else + { + report_unsupported_conv_spec_error(); 
+ return builder::ConvSpecialization::DEFAULT; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Tensor Layouts +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported layout combinations with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_layout_error() +{ + throw "Unsupported convolution layout combination detected!\n" + "The combination of ALayout, BLayout, and ELayout template parameters\n" + "is not recognized for the given spatial dimension.\n" + "Please verify that your convolution instance uses a supported layout configuration.\n" + "Check the conv_layout() function for the list of supported layout combinations."; +} + +/// @brief Derives the grouped convolution layout from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array containing the layouts for: +/// - [0] Input tensor layout +/// - [1] Weight tensor layout +/// - [2] Output tensor layout +/// @details This function examines the Instance's ALayout, BLayout, and ELayout types +/// along with the spatial dimension to determine the appropriate layout configuration. +/// +/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). +/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. +/// +/// @note Compilation will fail with a clear error message if the layout combination +/// is not supported for the given spatial dimension. +/// +/// TODO: If we don't check for supported layouts, this function can be simplified. 
+template +constexpr std::array conv_layout() +{ + using InstTraits = InstanceTraits; + using A = typename InstTraits::ALayout; + using B = typename InstTraits::BLayout; + using E = typename InstTraits::ELayout; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; + + // Helper to check if layouts match expected types + constexpr auto layouts_match = []() { + return std::is_same_v && std::is_same_v && std::is_same_v; + }; + + // Helper to construct layout array + constexpr auto make_layouts = [](auto in, auto weight, auto out) { + return std::array{in, weight, out}; + }; + + constexpr int spatial_dim = InstTraits::kSpatialDim; + + if constexpr(spatial_dim == 1) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNWC, GKXC, GNWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNWC, GKXC, GNWK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NWGC, GKXC, NWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCW, GKXC, NGKW); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCW, GKCX, NGKW); + else + { + report_unsupported_layout_error(); + return make_layouts(GNWC, GKXC, GNWK); // Unreachable + } + } + else if constexpr(spatial_dim == 2) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNHWC, GKYXC, GNHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNHWC, GKYXC, GNHWK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NHWGC, GKYXC, NHWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NHWGC, GKYXC, NHWGK); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCHW, GKYXC, NGKHW); + else if constexpr(layouts_match.template operator()()) + return make_layouts(NGCHW, GKCYX, NGKHW); + else + { + 
report_unsupported_layout_error(); + return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + } + } + else if constexpr(spatial_dim == 3) + { + if constexpr(layouts_match.template operator()()) + return make_layouts(GNDHWC, GKZYXC, GNDHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(GNDHWC, GKZYXC, GNDHWK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NDHWGC, GKZYXC, NDHWGK); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NGCDHW, GKZYXC, NGKDHW); + else if constexpr(layouts_match + .template operator()()) + return make_layouts(NGCDHW, GKCZYX, NGKDHW); + else + { + report_unsupported_layout_error(); + return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable + } + } + else + { + report_unsupported_layout_error(); + return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Data Types +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported data type with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_data_type_error() +{ + throw "Unsupported data type detected!\n" + "The ADataType is not recognized.\n" + "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " + "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " + "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " + "(BF8), " + "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" + "Please verify that your kernel instance uses a supported data type."; +} + +/// @brief Derives the data type from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. 
+/// @return A builder::DataType enum value representing the input data type. +/// @details This function examines the Instance's ADataType to determine the data type +/// used for the input tensor. The function supports various floating-point and integer +/// types, including tuple types for mixed-precision operations. +/// +/// Supported data types include: +/// - FP16 (ck::half_t) +/// - FP16_FP16 (ck::Tuple) +/// - BF16 (ck::bhalf_t) +/// - BF16_BF16 (ck::Tuple) +/// - FP32 (float) +/// - FP32_FP32 (ck::Tuple) +/// - FP64 (double) +/// - FP8 (ck::f8_t) +/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t) +/// - I8 (int8_t) +/// - I8_I8 (ck::Tuple) +/// - U8 (uint8_t) +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + using enum builder::DataType; + + if constexpr(std::is_same_v) + return FP16; + else if constexpr(std::is_same_v>) + return FP16_FP16; + else if constexpr(std::is_same_v) + return BF16; + else if constexpr(std::is_same_v>) + return BF16_BF16; + else if constexpr(std::is_same_v) + return FP32; + else if constexpr(std::is_same_v>) + return FP32_FP32; + else if constexpr(std::is_same_v) + return FP64; + else if constexpr(std::is_same_v) + return FP8; + else if constexpr(std::is_same_v) + return BF8; + else if constexpr(std::is_same_v) + return BF8; + else if constexpr(std::is_same_v) + return I8; + else if constexpr(std::is_same_v>) + return I8_I8; + else if constexpr(std::is_same_v) + return U8; + else + { + report_unsupported_data_type_error(); + return FP32; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Elementwise Operations +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported elementwise operation with a clear error message. 
+/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_elementwise_op_error() +{ + throw "Unsupported elementwise operation detected!\n" + "The elementwise operation type is not recognized.\n" + "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " + "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " + "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " + "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " + "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " + "UnaryConvert.\n" + "Please verify that your kernel instance uses a supported elementwise operation."; +} + +/// @brief Derives the elementwise operation from an operation functor type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A builder::ElementwiseOperation enum value corresponding to the operation. +/// @details This function uses the operation's type name to determine which elementwise +/// operation is being used. The comparison is case-insensitive. +/// +/// Supported operations include: +/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc. +/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc. +/// - Clamping operations: Clamp, AddClamp, etc. +/// - Combined operations: Add_Activation_Mul_Clamp, etc. +/// - Utility operations: PassThrough, UnaryConvert, etc. +/// +/// TODO: Consider changing this to direct checks on the types, not strings. 
+template +constexpr builder::ElementwiseOperation elementwise_op() +{ + using enum builder::ElementwiseOperation; + constexpr std::string_view name = detail::elementwise_op_name(); + + if constexpr(detail::case_insensitive_equal(name, "AddClamp")) + return ADD_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) + return ADD_RELU_ADD; + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + return BIAS_BNORM_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + return BILINEAR; + else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) + return BIAS_BNORM_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + return CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) + return CONV_INVSCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) + return CONV_SCALE; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) + return CONV_SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) + return CONV_SCALE_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + return SCALE; + else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) + return SCALE_ADD; + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + return PASS_THROUGH; + else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) + return SCALEADD_SCALEADD_RELU; + else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) + return DYNAMIC_UNARY_OP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) + return UNARY_COMBINED_OP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) + return ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) + return ACTIVATION_MUL_CLAMP; + else if 
constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) + return ADD_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) + return ADD_ACTIVATION_MUL2_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) + return ADD_MUL_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) + return ADD_MUL2_ACTIVATION_MUL_CLAMP; + else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) + return UNARY_CONVERT; + else if constexpr(detail::case_insensitive_equal(name, "Logistic")) + return LOGISTIC; + else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) + return CLIPPED_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Swish")) + return SWISH; + else if constexpr(detail::case_insensitive_equal(name, "Elu")) + return ELU; + else if constexpr(detail::case_insensitive_equal(name, "Power")) + return POWER; + else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) + return LEAKY_RELU; + else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) + return UNARY_ABS; + else if constexpr(detail::case_insensitive_equal(name, "Relu")) + return RELU; + else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) + return SOFT_RELU; + else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) + return SIGMOID; + else if constexpr(detail::case_insensitive_equal(name, "TanH")) + return TANH; + else if constexpr(detail::case_insensitive_equal(name, "Gelu")) + return GELU; + else if constexpr(detail::case_insensitive_equal(name, "Silu")) + return SILU; + else + { + report_unsupported_elementwise_op_error(); + return PASS_THROUGH; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// GEMM Padding +// ---------------------------------------------------------------------------- + +/// 
@brief Derives the GEMM padding specification from a kernel instance type. +/// @tparam Instance A device kernel instance type. +/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration. +/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits +/// and converts it to the builder framework's GemmPadding enum. The padding specification +/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes. +/// +/// Supported padding configurations include: +/// - DEFAULT: No padding +/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding +/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding +/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding +/// - MNKO_PADDING: All dimensions padded +template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto spec = InstTraits::kGemmSpecialization; + + switch(spec) + { + case Default: return DEFAULT; + case MPadding: return M_PADDING; + case NPadding: return N_PADDING; + case KPadding: return K_PADDING; + case MNPadding: return MN_PADDING; + case MKPadding: return MK_PADDING; + case NKPadding: return NK_PADDING; + case MNKPadding: return MNK_PADDING; + case OPadding: return O_PADDING; + case MOPadding: return MO_PADDING; + case NOPadding: return NO_PADDING; + case KOPadding: return KO_PADDING; + case MNOPadding: return MNO_PADDING; + case MKOPadding: return MKO_PADDING; + case NKOPadding: return NKO_PADDING; + case MNKOPadding: return MNKO_PADDING; + } +} + +// ============================================================================ +// SECTION 3: PIPELINE CONFIGURATION HELPERS +// ============================================================================ + +/// @brief Safely extracts the pipeline version from InstanceTraits. 
+/// @tparam InstTraits The InstanceTraits type to extract pipeline version from. +/// @return The pipeline version as a builder::PipelineVersion enum value. +/// @details This helper function checks if the InstanceTraits has a kPipelineVersion +/// field and extracts it if present. If not present, it returns a default value (V1). +/// This is necessary because not all convolution types expose pipeline version information. +template +constexpr builder::PipelineVersion get_pipeline_version() +{ + if constexpr(requires { InstTraits::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + return builder::PipelineVersion::V1; + } +} + +/// @brief Safely extracts the pipeline scheduler from InstanceTraits. +/// @tparam InstTraits The InstanceTraits type to extract pipeline scheduler from. +/// @return The pipeline scheduler as a builder::PipelineScheduler enum value. +/// @details This helper function checks if the InstanceTraits has a kPipelineScheduler +/// or kLoopScheduler field and extracts it if present. If neither is present, it returns +/// a default value (DEFAULT). This is necessary because different convolution types may +/// expose scheduler information through different field names. 
+template +constexpr builder::PipelineScheduler get_pipeline_scheduler() +{ + if constexpr(requires { InstTraits::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { InstTraits::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + return builder::PipelineScheduler::DEFAULT; + } +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp new file mode 100644 index 0000000000..00010e2d48 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -0,0 +1,8 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f5f3df3159..71db59afb6 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag +{ +}; + // 
Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index ace1b09224..4549b76a3f 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 09274d5acd..046e5c3078 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ 
b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp index b2e8bb6a7c..6875e586cd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp @@ -35,10 +35,10 @@ struct ReferenceCommonTraits typename builder::factory::internal::LayoutToCK::type; // Data types - extract from factory's type helper - using Types = builder::factory::internal::FwdConvTensorDataTypes; - using ADataType = typename Types::ADataType; - using BDataType = typename Types::BDataType; - using EDataType = typename Types::EDataType; + using Types = builder::factory::internal::ConvTensorDataTypes; + using ADataType = typename Types::InDataType; + using BDataType = typename Types::WeiDataType; + using EDataType = typename Types::OutDataType; using AccDataType = float; // Reference uses float accumulation // Elementwise operations - reference only supports PassThrough diff --git a/experimental/builder/include/ck_tile/builder/testing/README.md b/experimental/builder/include/ck_tile/builder/testing/README.md index 85adc59d80..c6662c2b04 
100644 --- a/experimental/builder/include/ck_tile/builder/testing/README.md +++ b/experimental/builder/include/ck_tile/builder/testing/README.md @@ -53,7 +53,7 @@ struct ConvSignature { ck_tile::builder::DataType data_type = ck_tile::builder::DataType::FP16; ck_tile::builder::ElementwiseOperation elementwise_operation = - ck_tile::builder::ElementwiseOperation::NONE; + ck_tile::builder::ElementwiseOperation::PASS_THROUGH; }; // Double-check that out structure is well-defined according to the CK-Builder API. @@ -66,7 +66,7 @@ constexpr auto SIGNATURE = ConvSignature{ .direction = ck_tile::builder::ConvDirection::FORWARD, .layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = ck_tile::builder::DataType::FP16, - .elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE, + .elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH, }; ``` @@ -243,7 +243,7 @@ struct ConvSignature { ck_tile::builder::DataType data_type = ck_tile::builder::DataType::FP16; ck_tile::builder::ElementwiseOperation elementwise_operation = - ck_tile::builder::ElementwiseOperation::NONE; + ck_tile::builder::ElementwiseOperation::PASS_THROUGH; }; static_assert(ck_tile::builder::ConvSignatureDescriptor); constexpr auto SIGNATURE = ConvSignature{ @@ -251,7 +251,7 @@ constexpr auto SIGNATURE = ConvSignature{ .direction = ck_tile::builder::ConvDirection::FORWARD, .layout = ck_tile::builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = ck_tile::builder::DataType::FP16, - .elementwise_operation = ck_tile::builder::ElementwiseOperation::NONE, + .elementwise_operation = ck_tile::builder::ElementwiseOperation::PASS_THROUGH, }; // Define the convolution algorithm diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index 62d265894a..d8910152dd 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ 
b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -7,11 +7,15 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/extent.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item /// of it is the ConvArgs structure - which contains a complete description @@ -37,12 +41,12 @@ namespace ck_tile::builder::test { template struct ConvTensorLengths { - size_t batch_size = 1; // N - size_t groups = 1; // G - size_t input_channels = 1; // C - size_t output_channels = 1; // K - Extent image = {}; // W, H, D - Extent filter = {}; // X, Y, Z + size_t batch_size = 1; // N + size_t groups = 1; // G + size_t input_channels = 1; // C + size_t output_channels = 1; // K + FilterExtent image = {}; // W, H, D + FilterExtent filter = {}; // X, Y, Z }; /// @brief `Args` specialization for forward convolution. @@ -59,12 +63,19 @@ struct Args constexpr static auto WEIGHT_TYPE = SIGNATURE.data_type; constexpr static auto OUTPUT_TYPE = SIGNATURE.data_type; - // TODO: We shouldn't need to call into an internal namespace here. 
- using Ops = factory::internal::ElementwiseOps; + constexpr static int INPUT_RANK = 3 + SPATIAL_DIM; + constexpr static int WEIGHT_RANK = 3 + SPATIAL_DIM; + constexpr static int OUTPUT_RANK = 3 + SPATIAL_DIM; + + using InputDescriptor = TensorDescriptor; + using WeightDescriptor = TensorDescriptor; + using OutputDescriptor = TensorDescriptor; // TODO: We shouldn't need to call into an internal namespace here. - using Layouts = - factory::internal::ConvTensorLayouts; + using Ops = factory::internal::ConvElementwiseOps; + + // TODO: We shouldn't need to call into an internal namespace here. + using Layouts = factory::internal::ConvTensorLayouts; ConvTensorLengths lengths; @@ -73,19 +84,19 @@ struct Args // implementation (based on ConvParam in old CK / CK Tile) does not // support strides at all. - Extent filter_strides; - Extent filter_dilation; - Extent input_left_pad; - Extent input_right_pad; + FilterExtent filter_strides; + FilterExtent filter_dilation; + FilterExtent input_left_pad; + FilterExtent input_right_pad; - Ops::AElementwiseOp a_elementwise_op; - Ops::BElementwiseOp b_elementwise_op; - Ops::CDEElementwiseOp cde_elementwise_op; + Ops::InElementwiseOp a_elementwise_op; + Ops::WeiElementwiseOp b_elementwise_op; + Ops::OutElementwiseOp cde_elementwise_op; /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_input_descriptor() const + InputDescriptor make_input_descriptor() const { // TODO: We're using old CK functionality to compute the right // values here, mainly because CK tile does not support the @@ -95,32 +106,38 @@ struct Args // function. 
const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< - typename Layouts::ALayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + typename Layouts::InLayout>(param); + using Extent = typename InputDescriptor::Extent; + return InputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the weight-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_weight_descriptor() const + WeightDescriptor make_weight_descriptor() const { // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< - typename Layouts::BLayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + typename Layouts::WeiLayout>(param); + using Extent = typename WeightDescriptor::Extent; + return WeightDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the output-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_output_descriptor() const + OutputDescriptor make_output_descriptor() const { // See note in implementation of `make_input_descriptor`. 
const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< - typename Layouts::ELayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + typename Layouts::OutLayout>(param); + using Extent = typename OutputDescriptor::Extent; + return OutputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// Convert the Args structure into a CK conv_param structure. This @@ -165,6 +182,12 @@ struct Inputs { void* input; void* weight; + + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("weight", args.make_weight_descriptor(), &Inputs::weight); + } }; /// @brief `Outputs` specialization for forward convolution. @@ -177,95 +200,24 @@ template struct Outputs { void* output; -}; -/// @brief `UniqueInputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see UniqueInputs -/// @see ValidUniqueInputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct UniqueInputs -{ - DeviceBuffer input_buf; - DeviceBuffer weight_buf; - - /// @see ValidUniqueInputs - Inputs get() + static void reflect(const Args& args, const auto& inspect) { - return { - .input = input_buf.get(), - .weight = weight_buf.get(), - }; + inspect("output", args.make_output_descriptor(), &Outputs::output); } }; -/// @brief `UniqueOutputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see UniqueOutputs -/// @see ValidUniqueOutputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct UniqueOutputs -{ - DeviceBuffer output_buf; - - /// @see ValidUniqueOutputs - Outputs get() - { - return { - .output = output_buf.get(), - }; - } -}; - -/// @brief `alloc_inputs()` specialization for forward convolution. 
-/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_inputs() -template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueInputs -UniqueInputs alloc_inputs(const Args& args) -{ - return { - .input_buf = alloc_tensor_buffer(args.make_input_descriptor()), - .weight_buf = alloc_tensor_buffer(args.make_weight_descriptor()), - }; -} - /// @brief `init_inputs()` specialization for forward convolution. /// /// @tparam SIGNATURE Forward convolution signature. /// /// @see alloc_inputs() template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs) + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) { - init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f); -} - -/// @brief `alloc_outputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. 
-/// -/// @see alloc_outputs() -template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueOutputs -UniqueOutputs alloc_outputs(const Args& args) -{ - return { - .output_buf = alloc_tensor_buffer(args.make_output_descriptor()), - }; + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index cc5c613d95..a90f53ba7d 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -3,10 +3,10 @@ #pragma once -#include -#include - #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include +#include /// This file contains the implementation details for invoking/testing /// grouped convolution operations in old CK. The main item is the @@ -15,6 +15,63 @@ namespace ck_tile::builder::test { +namespace detail { + +/// @brief Concept for checking whether a convolution is invoked like old +/// CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// with some utility aliases. For that reason, it is moved to this detail +/// namespace. +template > +concept CkConvInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { + { + conv.MakeArgument(p_a, + p_b, + // TODO: Support multiple D outputs.
+ {}, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // TODO: Ds lengths/strides + {}, + {}, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde) + }; +}; + +} // namespace detail + /// @brief Concept for checking whether a convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is @@ -24,13 +81,8 @@ namespace ck_tile::builder::test { /// /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. -template -concept IsCkConvInstance = - // TODO: This should be implemented by converting the signature into the - // type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave - // it empty. Improve when needed, you get the point. Also we should probably - // move this to the ck conv factory helper. - true; +template +concept CkConvInstance = detail::CkConvInstance; /// @brief `run()` specialization for forward convolution and old CK. /// @@ -39,10 +91,9 @@ concept IsCkConvInstance = /// operation. This should be caught and reported by the testing framework. /// /// @see run() -template - requires ValidConvSignature && ConvDirectionIsForward && - IsCkConvInstance -void run(Conv& conv, +template + requires ValidConvSignature && ConvDirectionIsForward +void run(CkConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs) diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp new file mode 100644 index 0000000000..85493e32eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/conv_fwd.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept RefConvInstance = requires(Conv& conv, + const void* input, + const void* weight, + void* output, + int G, + int N, + int K, + int C, + std::vector dims) { + { + conv.Run(input, + weight, + output, + G, + N, + K, + C, + dims, // input_spatial + dims, // filter_spatial + dims, // output_spatial + dims, // strides + dims, // dilations + dims // left_pads + ) + }; +}; + +/// @brief `run()` specialization for forward convolution and the reference +/// implementation. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @throws std::runtime_error if the arguments weren't actually valid for the +/// operation. This should be caught and reported by the testing framework. +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? 
+ // for now, just concern ourselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +void run(RefConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + // Right pads are not supported right now for some reason. + for(auto right_pad : param.input_right_pads_) + { + if(right_pad != 0) + throw std::runtime_error("TODO: Support right pad in reference conv"); + } + + if(!args.make_input_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed input tensor in reference conv"); + if(!args.make_weight_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv"); + if(!args.make_output_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(inputs.input, + inputs.weight, + outputs.output, + param.G_, + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.output_spatial_lengths_, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/debug.hpp b/experimental/builder/include/ck_tile/builder/testing/debug.hpp new file mode 100644 index 0000000000..4014d62d48 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/debug.hpp @@ -0,0 +1,634 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/type_traits.hpp" +#include "ck/utility/type_convert.hpp" +#include +#include +#include +#include +#include +#include +#include + +/// This file contains a few debugging utilities, mainly focused around +/// tensor data. The idea is that the functionality in this file is not +/// necessarily used in any testing directly, but is available for the +/// programmer to help with debugging problems. These utilities themselves +/// should be tested just the same, though, so that they don't undergo +/// bitrot while they are not actively being used. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Custom number punctuation for CK-Builder debugging. +/// +/// During debugging, the locale is usually left to the default C locale. +/// The C locale does not have any thousands separator, which makes +/// large numbers hard to read. This is a specialization of the default +/// C++ number punctuation (`std::numpunct`) which separates thousands +/// using `'`, which helps getting a quick overview of the magnitude of +/// a number. This character is chosen because C++14 allows number literals +/// to have this character. +/// +/// @note When using this locale, be sure to restore the old locale in the +/// event that the user actually wants to use a non-standard locale. +/// +/// @see std::numpunct +struct numpunct : std::numpunct +{ + char do_thousands_sep() const override { return '\''; } + + std::string do_grouping() const override + { + // See std::numpunct, this separates by thousands. + return "\3"; + } +}; + +} // namespace detail + +/// @brief Print information about a tensor descriptor. +/// +/// This function dumps useful information from a tensor descriptor to a +/// stream, `std::cout` by default. 
This includes the number of elements +/// in the tensor, the size of the backing space, lengths, strides, etc. +/// +/// @note All information is printed using a lightly modified locale to +/// get a unified printing experience. This locale is set on a temporary +/// stream that wraps the buffer of `out`; `out` itself is not imbued. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param name A name for the tensor descriptor. +/// @param desc The tensor descriptor to print. +/// @param out The stream to print to, `std::cout` by default. +template +void print_descriptor(std::string_view name, + const TensorDescriptor& desc, + std::ostream& out = std::cout) +{ + // Create a custom stream with a completely new config (locale, + // precision, fill, etc). Use an osyncstream to buffer the output + // while we're at it (it's not likely to help a lot, but why not). + std::osyncstream stream(out.rdbuf()); + stream.imbue(std::locale(std::locale(), new detail::numpunct{})); + + // Print name along with some generic info + const auto size = desc.get_element_size(); + const auto space = desc.get_element_space_size(); + const auto bytes = desc.get_element_space_size_in_bytes(); + const auto packed = desc.is_packed(); + + stream << "Descriptor \"" << name << "\":\n" + << " data type: " << DT << '\n' + << " size: " << size << " elements\n" + << " space: " << space << " elements (" << bytes << " bytes)\n" + << " lengths: " << desc.get_lengths() << '\n' + << " strides: " << desc.get_strides() << '\n' + << " packed: " << (packed ? "yes" : "no") << std::endl; +} + +/// @brief User configuration for printing tensors. +/// +/// This structure houses some configuration fields for customizing how tensors +/// are printed. The default is usually good, though `TensorPrintConfig::unlimited()` +/// is useful if you want to print the entire tensor to the output regardless of size.
+struct TensorPrintConfig +{ + /// @brief A limit for the number of columns in a tensor row to print. + /// + /// Each row of a tensor will be printed as a sequence of values. At most + /// this number of values are printed, if there are more, `row_skip_val` + /// will be printed in between. + size_t col_limit = 10; + + /// @brief A limit for the number of rows in a 2D matrix to print + /// + /// Tensors with rank higher than 1 are printed as a single matrix or a series + /// of matrix slices. At most this number of rows of the matrix will be printed. + /// If there are more rows, a row of `matrix_row_skip_val` and possibly + /// `row_skip_val` will be printed in between. + size_t row_limit = 10; + + /// @brief A limit for the number of 2D tensor slices to print. + /// + /// Tensors with rank higher than 2 are flattened into a sequence of slices. At + /// most this number of slices will be printed. + size_t slice_limit = 8; + + /// @brief Text to print at the start of a row of values. + /// + /// This is used by `TensorPrinter`, and printed at the start of a row of tensor + /// values. + std::string_view row_prefix = " "; + + /// @brief Text to print between fields of a row. + /// + /// This is used by `TensorPrinter`, and printed between each value of a row of + /// tensor values. + std::string_view row_field_sep = " "; + + /// @brief Text to print when skipping some number of row values. + /// + /// This is used by `TensorPrinter`, and printed instead of some number of values + /// when the number of values in a row is too large to all print. + std::string_view row_skip_val = "..."; + + /// @brief Text to print when skipping a row of a matrix. + /// + /// This is used by `TensorPrinter`, and printed instead of a value when some + /// number of rows is skipped when printing a matrix. This is similar to + /// `row_skip_val`, except in the vertical direction. Note that ALL values + /// in the skip row is printed this way. 
+ std::string_view matrix_row_skip_val = "..."; + + /// @brief The precision of tensor floating point values. + /// + /// Set the number of decimal digits that is printed for a floating point value. + int float_precision = 3; + + /// @brief Return the default print config, but without any printing limits. + /// + /// This is useful if you want to print the *entire* tensor, but be aware that + /// this may print a lot of data if the tensor is large! + constexpr static TensorPrintConfig unlimited() + { + return { + .col_limit = std::numeric_limits::max(), + .row_limit = std::numeric_limits::max(), + .slice_limit = std::numeric_limits::max(), + }; + } +}; + +namespace detail { + +/// @brief Iterate over a range of values, but limit the amount of iterations. +/// +/// Iterate over values `0..n`, but if `n > limit`, only iterate over the +/// first and last few (about `limit / 2` each) items. This can be used to iterate over +/// large ranges in a way that not too many values are visited. It is primarily +/// used when printing tensors so that not all values of a giant tensor are +/// dumped to the user's terminal. +/// +/// @param n The total number of items to iterate over. +/// @param limit The maximum number of items to iterate over. Use even values +/// for best results, as this will lead to the same amount of values in the +/// "begin" and "end" sections. +/// @param f A functor to invoke for each element. The sole parameter is the +/// index. +/// @param delim A functor to invoke between the begin and end sections. This +/// function is only invoked if any items are skipped at all. +void limited_foreach(size_t n, size_t limit, auto f, auto delim) +{ + if(n <= limit) + { + for(size_t i = 0; i < n; ++i) + f(i); + } + else + { + const auto begin_count = (limit + 1) / 2; // Round up in case `limit` is odd.
+ const auto end_count = limit / 2; + const auto skip_count = n - limit; + + for(size_t i = 0; i < begin_count; ++i) + f(i); + + delim(skip_count); + + for(size_t i = n - end_count; i < n; ++i) + f(i); + } +}; + +/// @brief Output stream requirements for use with `TensorPrinter`. +/// +/// The `TensorPrinter` does not write to an ostream directly, but rather writes to +/// a custom stream object. This is mainly so that the user of `TensorPrinter` can +/// get more details than directly with an ostream. Basically, a valid implementation +/// of `TensorPrintStream` exposes 3 things: +/// - A way to print (stringified) tensor elements. +/// - A way to print arbitrary text messages. These are mostly for formatting. This +/// should be implemented using varargs which are directly folded into an ostream, +/// so that functions can be used. +/// - A way to query the max width of any `val` field. +/// +/// @see TensorPrinter for more information. +template +concept TensorPrintStream = requires(Stream& stream, std::string_view val) { + { stream.max_width } -> std::convertible_to; + { stream.val(val) } -> std::same_as; + { stream.msg() } -> std::same_as; + { stream.msg("msg") } -> std::same_as; + { stream.msg(std::setw(3), std::setfill(4), "msg", val) } -> std::same_as; +}; + +/// @brief Utility to print tensors. +/// +/// This structure implements the main logic for printing tensors to a stream. +/// In order to help with formatting, the `TensorPrinter` abstracts over a custom +/// stream type, see `TensorPrintStream`. This type is actually mostly an internal +/// helper and mainly used by `print_tensor`. Its supposed to be constructed +/// manually, but see the field docs for what is required. +/// +/// @tparam DT The data type of the tensor to print. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to print. +/// +/// @see print_tensor +template +struct TensorPrinter +{ + /// The name of this tensor. 
This will be used during printing to add extra + /// clarity about what the user is seeing. + std::string_view name; + + /// Configuration details of how to print the tensor. This should be able to + /// be specified by the user, but the default is good in most cases. + TensorPrintConfig config; + + /// The lengths of the tensor to print. These values are directly from + /// `TensorDescriptor::get_lengths()`, stored here to avoid querying them + /// repeatedly. + Extent lengths; + + /// The strides of the tensor to print. These values are directly from + /// `TensorDescriptor::get_strides()`, stored here to avoid querying them + /// repeatedly. + Extent strides; + + /// The tensor's backing buffer. This memory should be host-accessible, for + /// example by copying it back to the host first. + const void* h_buffer; + + /// A common stringstream for stringifying tensor values. This is here mostly + /// so that we can cache the internal allocation. + std::stringstream ss; + + /// @brief Low-level tensor value stringifying function. + /// + /// Print value `value` to the stringstream `ss` (member value). This function + /// is the actual low-level printing function that prints each element of the + /// tensor. In order to get a robust printing implementation, the value is written + /// directly into a stringstream, which is then further processed to be actually + /// written to the output. This way, the format doesn't depend on the ostream + /// configuration. + /// + /// @param value The value to print to the stream. + void stringify_value(const void* value) + { + if constexpr(DT == DataType::UNDEFINED_DATA_TYPE) + { + ss << "??"; + return; + } + + using CKType = detail::cpp_type_t
; + const auto ck_value = *static_cast(value); + + if constexpr(DT == DataType::I32 || DT == DataType::I8 || DT == DataType::U8) + ss << ck_value; + else if constexpr(DT == DataType::FP64 || DT == DataType::FP32) + ss << std::fixed << std::setprecision(config.float_precision) << ck_value; + else if constexpr(DT == DataType::FP16 || DT == DataType::BF16 || DT == DataType::FP8 || + DT == DataType::BF8) + ss << std::fixed + << std::setprecision(config.float_precision) + // Note: We are using CK types here (cpp_type_t uses DataTypeToCK), so + // use CK's type_convert function. + << ::ck::type_convert(ck_value); + else + // TODO: Tuple types? Currently not implemented in DataTypeToCK... + static_assert(false, "stringify_value unsupported data type, please implement"); + } + + /// @brief Print the value at an index to a stream. + /// + /// This function reads the value at `index` and prints it to `stream` (using + /// `stream.val(...)`). + /// + /// @param stream The stream to print to. + /// @param index The index in the tensor of the value to print. + void print_value(TensorPrintStream auto& stream, const Extent& index) + { + const auto offset = calculate_offset(index, strides); + const auto* value_ptr = + &static_cast(h_buffer)[offset * data_type_sizeof(DT)]; + + // Reset the stream without allocating. + // ss.str("") allocates... + ss.clear(); + ss.seekg(0); + ss.seekp(0); + stringify_value(value_ptr); + // ss.view() returns a view of the ENTIRE buffer, which may have + // lingering data since we used seekp() and seekg() to reset the + // stream. For some reason std::stringstream works this way... + // Fortunately tellp() returns how many bytes we've actually + // written. + const auto view = ss.view().substr(0, ss.tellp()); + stream.val(view); + } + + /// @brief Print a 1D row to a stream. + /// + /// Print a row of tensor values to the stream. 
This function is used for both + /// 1D tensors and for rows of 2D tensors, in which the base coordinate is given + /// by `index`. Note that the print configuration is taken into account to avoid + /// flooding the user's terminal with values. + /// + /// @param stream The stream to print to. + /// @param index The index of the row to print. The rightmost index element is + /// ignored, as that is the index of the value _within_ the row. + void print_row(TensorPrintStream auto& stream, Extent& index) + { + // See note in `print_matrix`. + stream.msg(config.row_prefix); + limited_foreach( + lengths[RANK - 1], + config.col_limit, + [&](auto i) { + stream.msg(config.row_field_sep); + index[RANK - 1] = i; + print_value(stream, index); + }, + [&]([[maybe_unused]] auto skip_count) { + stream.msg(config.row_field_sep); + // Note: Not using stream.val(...) here because we don't want this + // field to partake in max_width computation, nor do we want to + // pad it to the max width. + stream.msg(config.row_skip_val); + }); + + stream.msg('\n'); + } + + /// @brief Print a 2D matrix to a stream. + /// + /// Print a matrix of tensor values to the stream. This function is used for both + /// 2D and slices of higher-dimensional tensors, in which the base coordinate is + /// given by `index`. Note that the print configuration is taken into account to + /// avoid flooding the user's terminal with values. + /// + /// @param stream The stream to print to. + /// @param index The index of the row to print. The 2 rightmost index elements are + /// ignored, as those are the indices of values _within_ the matrix. + void print_matrix(TensorPrintStream auto& stream, Extent& index) + { + limited_foreach( + lengths[RANK - 2], + config.row_limit, + [&](auto i) { + index[RANK - 2] = i; + print_row(stream, index); + }, + [&]([[maybe_unused]] auto row_skip_count) { + // When we encounter a skip row, continue with the same logic + // as printing 1D tensor rows. 
Instead of actual values, we will + // simply print MATRIX_ROW_SKIP_VAL (usually something like "..."). + stream.msg(config.row_prefix); + limited_foreach( + lengths[RANK - 1], + config.col_limit, + [&]([[maybe_unused]] auto i) { + stream.msg(config.row_field_sep); + // Note: We're using `stream.val(...)` here because we *do* want this field + // to partake in max_width computation, and we *do* want to pad it like + // value fields. This is so that these appear the same width as actual + // values, so that everything is neatly aligned. This also ensures that if + // there are no skip values, then the size of the skip field is not taken + // into account. + stream.val(config.matrix_row_skip_val); + }, + [&]([[maybe_unused]] auto col_skip_count) { + stream.msg(config.row_field_sep); + // Note: Not using stream.val(...) here because we don't want this + // field to partake in max_width computation, nor do we want to + // pad it to the max width. + stream.msg(config.row_skip_val); + }); + stream.msg('\n'); + }); + } + + /// @brief Print a tensor to a stream. + /// + /// This is the main tensor printing function. It calls `print_row` or `print_matrix` + /// (possibly repeatedly) as required. This function prints the entire tensor in + /// `h_buffer` regardless. + /// + /// @param stream The stream to print to. 
+ void print_tensor(TensorPrintStream auto& stream) + { + Extent zero_coord = {}; + if constexpr(RANK == 0) + { + // 0D case: just print the one value + stream.msg(config.row_prefix); + stream.msg(config.row_field_sep); + print_value(stream, zero_coord); + stream.msg('\n'); + } + else if constexpr(RANK == 1) + { + // 1D case: dump everything on one line + print_row(stream, zero_coord); + } + else if constexpr(RANK == 2) + { + // 2D case: print a 2D matrix + print_matrix(stream, zero_coord); + } + else + { + // For higher dimensions, print each window as a slice + // We want to limit the *total* number of slices using `slice_limit`, + // not the number in each axis. So flatten the remaining dimensions. + // This also avoids recursion in this function in general. + + // First get the shape minus the 2 inner dimensions + Extent outer_shape; + std::copy_n(lengths.begin(), RANK - 2, outer_shape.begin()); + + NdIter iter(outer_shape); + detail::limited_foreach( + iter.numel(), + config.slice_limit, + [&](auto outer_flat_index) { + // Now decode the outer index and turn it back into a complete index + const auto outer_index = iter(outer_flat_index); + Extent index = {}; + std::copy_n(outer_index.begin(), RANK - 2, index.begin()); + + // Print an extra separating line between two slices + if(outer_flat_index != 0) + stream.msg('\n'); + + // Print an information header about the current slice + stream.msg("Tensor \"", name, "\", slice ["); + for(auto x : outer_index) + stream.msg(x, ", "); + stream.msg(":, :]\n"); + + // And print is as matrix + print_matrix(stream, index); + }, + [&](auto skip_count) { stream.msg("\n(skipping ", skip_count, " slices...)\n"); }); + } + } +}; + +/// @brief Implementation of `TensorPrintStream` to figure out the maximum +/// width of a field. +/// +/// In order to produce neatly aligned tensors, where all values of each row +/// appear on the same columns, we have to figure out the maximum width of +/// each field. 
This print stream helps with that: It does not actually print +/// anything, it just figures out the maximum width of any value (not message). +/// +/// @details The values passed to `val()` have already been stringified by +/// the caller, so this type needs no stream of its own: the width of a +/// field (in bytes) is simply the size of the string view that is passed +/// in. +/// +/// @see TensorPrintStream +struct MaxFieldWidthStream +{ + size_t max_width = 0; + + /// @brief Print a tensor value to the stream + /// + /// "Print" a value to the stream. This function figures out the width + /// of the value when printed, and then composes it with `max_width` to + /// figure out the total maximum. + /// + /// @param value The value to print. + void val(std::string_view value) { max_width = std::max(max_width, value.size()); } + + /// @brief Print a message to the stream. + /// + /// "Print" a non-value message to the stream. In this implementation, + /// everything is discarded. + /// + /// @tparam Args the types of the values to print. + /// + /// @param args The values to print. + template + void msg([[maybe_unused]] const Args&... args) + { + } +}; + +/// @brief Implementation of `TensorPrintStream` which actually prints. +/// +/// In contrast to `MaxFieldWidthStream`, this type actually prints +/// to an ostream, taking the value produced by that type into account. +struct OutputStream +{ + std::ostream& stream; + // The maximum width of each tensor value. + size_t max_width; + + /// @brief Print a tensor value to the stream + /// + /// Actually print a value into the stream, padding it with spaces to + /// `max_width` (on the left, i.e. right-aligned, unless `stream` uses `std::left`). + /// + /// @param value The value to print. + void val(std::string_view value) + { + stream << std::setfill(' ') << std::setw(max_width) << value; + } + + /// @brief Print a message to the stream. + /// + /// This prints a non-value message directly to the ostream, as if + /// folded via `operator<<`.
+ /// + /// @tparam Args the types of the values to print. + /// + /// @param args The values to print. + template + void msg(const Args&... args) + { + (stream << ... << args); + } +}; + +} // namespace detail + +/// @brief Print device tensor values to an ostream. +/// +/// Print the values of a tensor to an ostream. This function neatly formats +/// the tensor according to `config`, tabulating the values so that they are +/// vertically aligned and skipping values to prevent flooding the terminal. +/// With the default config, this function is good to get a quick overview +/// of what a tensor looks like. For a more complete overview, consider +/// supplying `TensorPrintConfig::unlimited()` to get everything (but beware +/// of flooding the terminal). Tensors are printed with the rightmost-dimension +/// as inner dimension, these values appear on the same row in the output. +/// +/// @tparam DT The data type of the tensor. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param name A name for the tensor. This will be used to add some extra identifying +/// information during printing. +/// @param desc The descriptor for the tensor memory layout. +/// @param d_buffer The tensor's actual data buffer. This is expected to be +/// _device accessible_ memory, as its copied back to the host first. +/// @param config Tensor printing configuration. This allows tweaking some details +/// of the printing process. +/// @param out The ostream to print to, `std::cout` by default. 
+template +void print_tensor(std::string_view name, + const TensorDescriptor& desc, + const void* d_buffer, + TensorPrintConfig config = {}, + std::ostream& out = std::cout) +{ + // Copy memory to the host (printing from device is sketchy) + const auto space = desc.get_element_space_size_in_bytes(); + std::vector h_buffer(space); + check_hip(hipMemcpy(h_buffer.data(), d_buffer, space, hipMemcpyDeviceToHost)); + + // Create a custom stream with a completely new config (locale, + /// precision, fill, etc). Use an osyncstream to buffer the output + /// while were at it (its not likely to help a lot, but why not). + std::osyncstream stream(out.rdbuf()); + stream.imbue(std::locale(std::locale(), new detail::numpunct{})); + + // Print a header for the entire tensor (regardless of if there are multiple slices). + stream << "Tensor \"" << name << "\": shape = " << desc.get_lengths() << "\n"; + + detail::TensorPrinter printer = { + .name = name, + .config = config, + .lengths = desc.get_lengths(), + .strides = desc.get_strides(), + .h_buffer = h_buffer.data(), + .ss = std::stringstream(), + }; + + // We're actually going to print twice: once to figure out the + // maximum width of the fields, and once to actually print to the stream. + + // Print once to figure out the maximum field width. + detail::MaxFieldWidthStream max_field_width; + printer.print_tensor(max_field_width); + + // Actually print to the output stream. + detail::OutputStream tensor_out = { + .stream = stream, + .max_width = max_field_width.max_width, + }; + printer.print_tensor(tensor_out); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/error.hpp b/experimental/builder/include/ck_tile/builder/testing/error.hpp new file mode 100644 index 0000000000..242f2a8e51 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/error.hpp @@ -0,0 +1,150 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +/// This file defines some utilities for dealing with HIP errors. In the CK-Builder +/// testing code, we'd like to just turn them into exceptions: This cleans up testing +/// code as we don't need to think about returning error codes, but its still much +/// cleaner than just creating a hard crash and thereby possibly interrupting other +/// units in the same test. The testing framework can catch these exceptions where +/// necessary. +/// +/// While the exceptions defined in this file are in principle suitable for general +/// usage, HIP functions which return HIP error codes (`hipError_t`) should be +/// checked using the `check_hip` function. + +namespace ck_tile::builder::test { + +/// @brief Generic HIP exception. +/// +/// This is a derivation of `std::runtime_error` which represents a HIP error code. +/// +/// @see std::runtime_error +/// @see hipError_t +struct HipError : std::runtime_error +{ + /// @brief Utility for formatting HIP error messages + /// + /// Returns a human-readable description of a HIP error. Given a description of the + /// activity that the user tried to perform, this function appends the HIP-specific + /// information such as the stringified version of the error code, and the error + /// code itself (for reference). + /// + /// @param user_msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. + static std::string + format_error(std::string_view user_msg, hipError_t code, std::source_location src) + { + std::stringstream msg; + msg << user_msg << ": " << hipGetErrorString(code) << " (" << code << ")"; + if(src.function_name()) + msg << " in function '" << src.function_name(); + msg << "' at " << src.file_name() << ":" << src.line() << ":" << src.column(); + return msg.str(); + } + + /// @brief Construct a generic HIP error. 
+ /// + /// @param msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + HipError(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) + : std::runtime_error(format_error(msg, code, src)), code_(code) + { + } + + /// @brief Retrieve the inner error code. + /// + /// This function returns the status code that was encountered while checking an + /// operation for errors. + hipError_t code() const { return code_; } + + private: + hipError_t code_; +}; + +/// @brief HIP out of memory error. +/// +/// This a derivation of `HipError` which is specialized for Out-of-memory errors. This +/// makes it easier to attach additional context, and to match on these errors while +/// using `catch` blocks. +/// +/// @see HipError +struct OutOfDeviceMemoryError : HipError +{ + /// @brief Construct an out-of-device-memory error. + /// + /// @param msg User-given message about the activity at time of error. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + OutOfDeviceMemoryError(std::string_view msg = "failed to allocate device memory", + std::source_location src = std::source_location::current()) + : HipError(msg, hipErrorOutOfMemory, src) + { + } +}; + +/// @brief Check HIP status for errors. +/// +/// This function checks a HIP status code (obtained from a HIP function call) for any +/// errors. If the status `code` is not `hipSuccess`, this function throws an instance of +/// `HipError`. The exact type thats thrown depends on the status. If `code` represents +/// an out-of-memory error `hipErrorOutOfMemory`, then `OutOfDeviceMemoryError` will be +/// thrown instead. +/// +/// @param msg User-given message about the activity at possible time of error. +/// @param code The HIP status code to examine. 
+/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +inline void check_hip(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) +{ + // -Wswitch-enum throws a warning if this code is changed into a switch, even with + // the `default` label... + + if(code == hipSuccess) + // When you beat the error allegations + return; + else if(code == hipErrorOutOfMemory) + throw OutOfDeviceMemoryError(msg, src); + else + throw HipError(msg, code, src); +} + +/// @brief Check HIP status for errors. +/// +/// This function is similar to `check_hip(std::string_view, hipError_t)`, except that a +/// default message is given. +/// +/// @param code The HIP status code to examine. +/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +/// @see check_hip(std::string_view, hipError_t) +inline void check_hip(hipError_t code, std::source_location src = std::source_location::current()) +{ + check_hip(code == hipErrorOutOfMemory ? 
"failed to allocate device memory" + : "HIP runtime error", + code, + src); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/extent.hpp b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp similarity index 50% rename from experimental/builder/include/ck_tile/builder/testing/extent.hpp rename to experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp index a2d9b3ff4c..3587ac406f 100644 --- a/experimental/builder/include/ck_tile/builder/testing/extent.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp @@ -5,28 +5,29 @@ namespace ck_tile::builder::test { -/// This structure describes a 1-, 2-, or 3-D extent. Its used to -/// communicate 1-, 2- or 3-D sizes and strides of tensors. -/// Depending on the dimension, the structure will have the `width`, -/// `height`, and `depth` fields available. +/// This structure describes a 1-, 2-, or 3-D extent for convolution +/// filters. Its used to communicate 1-, 2- or 3-D sizes and strides +/// of tensors, specifically for convolution filters. Depending on the +/// dimension, the structure will have the `width`, `height`, and +/// `depth` fields available. 
template -struct Extent; +struct FilterExtent; template <> -struct Extent<1> +struct FilterExtent<1> { size_t width = 1; }; template <> -struct Extent<2> +struct FilterExtent<2> { size_t width = 1; size_t height = 1; }; template <> -struct Extent<3> +struct FilterExtent<3> { size_t width = 1; size_t height = 1; diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp index 42f85f8017..3f5a9dd465 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp @@ -3,19 +3,15 @@ #pragma once +#include "ck_tile/builder/testing/error.hpp" +#include #include #include -#include -#include -#include -#include -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/testing/type_traits.hpp" -#include "ck_tile/host/host_tensor.hpp" +#include -/// This file deals with tensor memory allocation: Both the act of allocating -/// and (automatically) deallocating memory, as well as utilities for managing -/// the layout of tensor data in memory. +/// This file deals with tensor memory management and allocation. The main +/// item is the `DeviceBuffer`: An owned piece of device memory, which is +/// automatically freed when it goes out of scope. namespace ck_tile::builder::test { @@ -39,31 +35,6 @@ struct DeviceMemoryDeleter } }; -/// @brief HIP out of memory error -/// -/// This is a derivation of `std::runtime_error` specialized for HIP -/// out-of-memory errors. -/// -/// @see std::runtime_error -struct OutOfDeviceMemoryError : std::runtime_error -{ - /// @brief Utility for formatting out-of-memory error messages - /// - /// Returns a human-readable description of a HIP out-of-memory error. 
- /// - /// @param status The status to report - static std::string format_error(hipError_t status) - { - return std::string("failed to allocate hip memory: ") + hipGetErrorString(status) + " (" + - std::to_string(status) + ")"; - } - - /// @brief Construct an out-of-memory error using `status` as message. - /// - /// @param status A HIP error status that was encountered while allocating memory. - OutOfDeviceMemoryError(hipError_t status) : std::runtime_error(format_error(status)) {} -}; - /// @brief Automatically managed GPU memory. /// /// The `DeviceBuffer` is an automatically managed pointer for GPU memory. When @@ -96,117 +67,29 @@ inline DeviceBuffer alloc_buffer(size_t size) std::byte* d_buf = nullptr; if(const auto status = hipMalloc(&d_buf, size); status != hipSuccess) { - throw OutOfDeviceMemoryError(status); + // Add some additional context + + size_t free, total; + check_hip("failed to get HIP memory info", hipMemGetInfo(&free, &total)); + + std::stringstream ss; + ss << "failed to allocate device memory (tried to allocate " << size << " bytes with only " + << free << " available)"; + + throw OutOfDeviceMemoryError(ss.str()); } return DeviceBuffer(d_buf); } -/// @brief Type managing tensor data layout in memory. +/// @brief "Align" an offset to a multiple of a particular alignment. /// -/// This structure describes a tensor in memory. It does not actually hold any -/// reference to memory, it just describes how the memory should be laid out if it -/// were. +/// Returns `addr` aligned to the next multiple of `alignment`. /// -/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it -/// also includes the data type of the elements of htis tensor. This is mainly to -/// make the descriptor a _complete_ description of a tensor rather than just the -/// dimensions in strides, which helps in reducing clutter in uses of this type. -/// -/// @note All strides are still in _elements_. 
/// @brief "Align" an offset to a multiple of a particular alignment.
///
/// Returns `addr` rounded up to the next multiple of `alignment`. A value
/// which already is a multiple of `alignment` is returned unchanged.
///
/// @param addr The address (or offset) to align.
/// @param alignment The alignment. It does not have to be a power of two. An
///   alignment of 0 is treated as "no alignment requirement" and yields `addr`
///   itself, instead of dividing by zero.
///
/// @returns The smallest multiple of `alignment` that is not less than `addr`.
inline size_t align_fwd(size_t addr, size_t alignment)
{
    // `addr % 0` is undefined behaviour, so handle it before doing any math.
    if(alignment == 0)
        return addr;

    const size_t remainder = addr % alignment;
    return remainder == 0 ? addr : addr + (alignment - remainder);
}
This, in turn, allows inferring the rank +/// from the instance (useful in combination with `make_descriptor`), as it alows +/// us to write `function(Extent{1, 2, 3})`. Note that `function({1, 2, 3})` is +/// not valid before C++26 because `{1, 2, 3}` is an initializer list (even if +/// `function` accepts an instance of `Extent`), which does not have a known size +/// at compile time. Second, creating a separate struct for the `Extent` allows +/// additional (static) member functions. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor that this +/// extent describes a size of. +/// +/// @see TensorDescriptor +/// @see make_descriptor +template +struct Extent : std::array +{ + using Base = std::array; + // Note: Default constructor inherited from std::array. + + /// @brief Construct an extent from an `std::vector`. + /// + /// This function can be used to turn an `std::vector` into an `Extent`. + /// Because this code is mainly intended for testing, the vector's size is + /// checked. If its not equal to `RANK`, an exception is thrown. + /// + /// @throws std::runtime_error if the size of `extent` is not equal to `RANK`. + static Extent from_vector(const std::vector& extent) + { + if(extent.size() != RANK) + { + std::stringstream msg; + msg << "invalid rank! expected: " << RANK << ", got: " << extent.size(); + throw std::runtime_error(msg.str()); + } + + Extent result; + std::copy_n(extent.begin(), RANK, result.begin()); + return result; + } + + // Note: std::array doesn't like generating indexing code when the RANK + // is zero. Looks like there is a missing __device__ overload in ROCm 7.1 + // at least. Its not terribly important, but just override the default + // operator[] to fix it. + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with. 
+ /// + /// @see std::array::operator[] + __device__ __host__ size_t operator[](size_t i) const + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with. + /// + /// @see std::array::operator[] + __device__ __host__ size_t& operator[](size_t i) + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } +}; + +// This is a deduction guideline necessary to resolve `Extent{1, 2, 3}` to the +// correct type. This definition is practically the same as that of `std::array`. +template +Extent(T...) -> Extent; + +/// @brief Extent printer +/// +/// This function implements an ostream printing overload for `Extent`, so that +/// they can be printed in the usual `stream << extent` fashion. +/// +/// @tparam RANK Rank (number of spatial dimensions) of the extent. +/// +/// @param stream The stream to print the extent to. +/// @param extent The extent to print to the stream. +template +std::ostream& operator<<(std::ostream& stream, const Extent& extent) +{ + stream << '['; + bool first = true; + for(const auto x : extent) + { + if(first) + first = false; + else + stream << ", "; + + stream << x; + } + + return stream << ']'; +} + +/// @brief Concept for automatically deriving tensor memory layout. +/// +/// A `TensorStridesGenerator` is a type which can be used to automatically +/// derive the strides (memory layout) of a tensor, given the tensor lengths. +/// This is mainly used to avoid manually computing strides. +/// +/// Implementors of this concept are required to implement `operator()`, +/// which accepts an instance of `Extent` (the tensor lengths) and +/// yields another instance of `Extent` (the tensor strides). 
Note +/// that the returned strides are expected to be "pre-scanned", meaning +/// that the offset in memory of a tensor can be computed as +/// `dot(index * strides)` (where `*` is element-wise multiplication). +/// +/// @see TensorDescriptor +/// @see PackedRightLayout +/// @see PackedLeftLayout +template +concept TensorStridesGenerator = requires(const G& generator, const Extent& lengths) { + { generator(lengths) } -> std::convertible_to>; +}; + +/// @brief Layout generator where right-most dimension has stride 1 and +/// all dimensions are packed. +/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the right-most dimension equal to 1, and +/// all other strides increase right-to-left as a products of the extent. +/// This corresponds with a row-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedRightLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedRightLayout`. + /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = RANK; i > 0; --i) + { + strides[i - 1] = numel; + numel *= lengths[i - 1]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedRightLayout should be a TensorStridesGenerator!"); + +/// @brief Layout generator where left-most dimension has stride 1 and +/// all dimensions are packed. +/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the left-most dimension equal to 1, and +/// all other strides increase left-to-right as a products of the extent. 
+/// This corresponds with a column-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedLeftLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedLeftLayout`. + /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = 0; i < RANK; ++i) + { + strides[i] = numel; + numel *= lengths[i]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedLeftLayout should be a TensorStridesGenerator!"); + +/// @brief Type managing tensor data layout in memory. +/// +/// This structure describes a tensor in memory. It does not actually hold any +/// reference to memory, it just describes how the memory should be laid out if it +/// were. +/// +/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it +/// also includes the data type of the elements of htis tensor. This is mainly to +/// make the descriptor a _complete_ description of a tensor rather than just the +/// dimensions in strides, which helps in reducing clutter in uses of this type. +/// +/// @note All strides are still in _elements_. +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptial spatial dimensions that +/// the tensor covers. +template +struct TensorDescriptor +{ + // For now, the implementation of this type is based on + // `ck_tile::HostTensorDescriptor`, so that we can prototype without + // reimplementing the `HostTensorDescriptor` for the 3rd time. 
You can regard + // the use of `ck_tile::HostTensorDescriptor` here as an implementation detail. + + /// @brief Tensor extent alias + /// + /// This alias represents a std::array which holds tensor dimensions. There is one + /// item for each dimension in the tensor, and each item corresponds with the + /// value for that dimension. + using Extent = ::ck_tile::builder::test::Extent; + + /// The conceptual data type of the tensor elements. This need not be the type + /// that the data is actually stored as in memory. + constexpr static DataType data_type = DT; + + /// The tensor "rank": the number of conceptial spatial dimensions that the + /// tensor covers. + constexpr static size_t rank = RANK; + + /// @brief Create a tensor descriptor from lengths and strides. + /// + /// @param lengths A sequence of tensor lengths, the conceptial dimensions of + /// the tensor in elements. + /// @param strides A sequence of in-memory strides of the tensor, measured in + /// elements. Each element of `strides`` corresponds to one at the same index + /// in `lengths`, the amount of elements to skip in memory to find the next + /// element along that axis. + TensorDescriptor(const Extent& lengths, const Extent& strides) + : inner_descriptor_(lengths, strides) + { + // TODO: Validation of strides? For now we just delegate the details of the + // construction to the CK Tile HostTensorDescriptor. + } + + /// @brief Create a tensor descriptor with lengths and automatic layout. + /// + /// This function initializes a tensor descriptor using lengths, and by deriving + /// the memory layout from the layout generator `Generator`. The tensor will be + /// initialized with the strides yielded from `Generator`. + /// + /// @tparam Generator The generator type to generate the strides with. For example, + /// `PackedRightLayout` or `PackedLeftLayout`. + /// + /// @param lengths A sequence of tensor lengths, the conceptial dimensions of + /// the tensor in elements. 
+ /// @param gen An instance of `Generator` to generate the strides with. + /// + /// @see TensorStridesGenerator + /// @see PackedLeftLayout + /// @see PackedRightLayout + template + requires TensorStridesGenerator + TensorDescriptor(const Extent& lengths, const Generator& gen) + : TensorDescriptor(lengths, gen(lengths)) + { + } + + /// Query the conceptual dimensions of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis. Note that the order + /// does *not* correspond with memory layout, query the in-memory strides for that. + /// + /// @see get_strides() + Extent get_lengths() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return lengths_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_lengths().begin(), RANK, result.begin()); + return result; + } + + /// Query the in-memory strides of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis. Each element + /// corresponds directly with the stride in elements at the same index in the + /// tensor dimensions. + /// + /// @see get_lengths() + Extent get_strides() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return strides_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_strides().begin(), RANK, result.begin()); + return result; + } + + /// @brief Compute conceptual tensor size in elements. + /// + /// This function returns the size of the tensor in elements. This function only + /// takes the lengths into account, not the strides. In order to allocate memory + /// for the tensor, use `get_element_space_size()`. + /// + /// @see get_lengths + /// @see get_element_space_size + size_t get_element_size() const { return inner_descriptor_.get_element_size(); } + + /// @brief Compute total tensor space size in elements. 
+ /// + /// This function returns the total size of the memory backing a tensor with + /// this descriptor in *elements*, including required extra size for strides. + /// + /// @see get_element_space_size_in_bytes() + size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); } + + /// @brief Compute total tensor size in bytes. + /// + /// This function is like `get_element_space_size()`, except that the returned + /// value is measured in *bytes* rather than *elements*. Use this function for + /// figuring out how much memory needs to be allocated for a particular tensor. + /// + /// @see get_element_space_size() + size_t get_element_space_size_in_bytes() const + { + // For now, the backing type is the naive C++-type that represents the data + // type. When we are going to support packed types such as i4 and fp6, this + // is going to become more complicated. + return get_element_space_size() * data_type_sizeof(DT); + } + + /// @brief Check if a tensor is packed in memory. + /// + /// This function checks whether the tensor memory is "packed", that is, whether + /// all elements are continuous in memory with no gaps. + bool is_packed() const + { + // First sort by stride, then check if they match the scan of the + // sizes. + const auto& lengths = inner_descriptor_.get_lengths(); + const auto& strides = inner_descriptor_.get_strides(); + + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](auto i, auto j) { + return strides[i] < strides[j]; + }); + + size_t x = 1; + for(size_t i = 0; i < RANK; ++i) + { + if(strides[indices[i]] != x) + return false; + + x *= lengths[indices[i]]; + } + + return true; + } + + /// @brief Get a tensor descriptor for the space backing a tensor. + /// + /// This function returns a tensor descriptor which represents the buffer space + /// required to a tensor with this descriptor. 
This is mainly useful to process + /// buffers with functions which normally operate over tensor descriptors. The + /// resulting tensor descriptor describes a 1D tensor with the same number of + /// elements as in the space. + /// + /// @see get_element_space_size() + TensorDescriptor get_space_descriptor() const + { + ck_tile::builder::test::Extent<1> lengths = {this->get_element_space_size()}; + ck_tile::builder::test::Extent<1> strides = {1}; + return TensorDescriptor(lengths, strides); + } + + private: + ck_tile::HostTensorDescriptor inner_descriptor_; +}; + +/// @brief Tensor descriptor construction helper. +/// +/// This function can be used to create a tensor descriptor. It accepts the same +/// parameters as the constructor of `TensorDescriptor`, that is, a sequence of +/// lengths and a sequence of strides (or a generator to generate the strides). +/// The main use of this function is that it allows automatic inference of the `RANK` +/// parameter. C++ constructors do not allow partial specification of type parameters, +/// and so its impossible to write `TensorDescriptor
x(Extent{1, 2, 3}, ...)` +/// and have the `RANK` be automatically inferred. Functions do allow this though, +/// so this function can be used to write `make_descriptor(Extent{1, 2, 3}, ...)` +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptual spatial dimensions that +/// the tensor covers. +/// +/// @param lengths A sequence of tensor lengths, the conceptual dimensions of +/// the tensor in elements. +/// @param strides A sequence of in-memory strides of the tensor, or a generator +/// to generate those strides from the tensor lengths. +/// +/// @see TensorDescriptor +template +TensorDescriptor make_descriptor(const Extent& lengths, const auto& strides) +{ + return TensorDescriptor(lengths, strides); +} + +/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor. +/// +/// This function is similar to `alloc_buffer()`, except that the required size is +/// derived automatically from a tensor descriptor. The returned buffer is valid for +/// tensors with that layout. Strides are also taken into account when computing the +/// required size. +/// +/// @tparam DT The conceptual datatype of the elements of the tensor. +/// @tparam RANK The conceptual rank (number of dimensions) of the tensor. +/// +/// @param descriptor A descriptor of the memory layout of the tensor to allocate. +/// +/// @throws OutOfDeviceMemoryError if memory allocation failed.
+/// +/// @see TensorDescriptor +/// @see DeviceBuffer +/// @see OutOfDeviceMemoryError +/// @see hipMalloc() +template +DeviceBuffer alloc_tensor_buffer(const TensorDescriptor& descriptor) +{ + return alloc_buffer(descriptor.get_element_space_size_in_bytes()); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp new file mode 100644 index 0000000000..28ab954de9 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp @@ -0,0 +1,351 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include +#include + +/// This file implements a generic GPU tensor "foreach" function. This +/// functionality turned out useful in separate parts of the testing +/// system, hence it's implemented in a separate file. This version is +/// not particularly efficient (but it should at least be readable), +/// but it should be easy to replace the implementation in the future, +/// should that be needed. + +namespace ck_tile::builder::test { + +/// @brief Utility structure for N-dimensional iteration using a flat index +/// +/// This structure's main purpose is to "unmerge" a flattened index into a +/// multi-dimensional index, which helps when iterating over multi-dimensional +/// indices without having to write an arbitrary amount of nested for loops. +/// A minimal amount of precomputation must be done to do this efficiently, +/// which is handled in the constructor of this type. +/// +/// @details Decoding a flat index into a multi-dimensional index is done by +/// first computing a reverse scan of the shape.
These values can then be +/// used to decode the index in the usual way: +/// +/// x = flat_idx / (size_y * size_z) +/// y = flat_idx % (size_y * size_z) / size_z +/// z = flat_idx % (size_y * size_z) % size_z +/// etc +/// +/// The decode order is such that the innermost dimension (right in +/// the shape extent) changes the fastest. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to +/// iterate. +template +struct NdIter +{ + /// @brief Prepare N-dimensional iteration over a particular shape. + /// + /// Precompute a shape into a form that can be used to easily decode a flat + /// index into a multi-dimensional index. + /// + /// @param shape The shape to iterate over. + explicit NdIter(const Extent& shape) + { + // Precompute shape_scan = [..., shape[-2] * shape[-1], shape[-1], 1] + + numel_ = 1; + for(int i = RANK; i > 0; --i) + { + shape_scan_[i - 1] = numel_; + numel_ *= shape[i - 1]; + } + } + + /// @brief Unflatten a flat index into a multi-dimensional index + /// + /// This applies the usual multi-dimensional indexing method over the + /// precomputed shape scan to get back a multi-dimensional index. + /// The decode order is such that the innermost dimension (right in + /// the shape extent) changes the fastest. + /// + /// @param flat_index The "flattened" (1-dimensional) index of the tensor + /// + /// @returns A multi-dimensional index into the tensor + /// + /// @pre `0 <= flat_index < numel()` (in other words, the `flat_index` must + /// be in bounds of the tensor shape that this `NdIter` was made from). + __host__ __device__ Extent operator()(size_t flat_index) const + { + Extent index = {}; + auto idx = flat_index; + for(size_t i = 0; i < RANK; ++i) + { + const auto scanned_dim = shape_scan_[i]; + index[i] = idx / scanned_dim; + idx %= scanned_dim; + } + + return index; + } + + /// @brief Return the total elements to iterate over + /// + /// Get the total number of elements in the shape to iterate over.
This value + /// can be used to construct a complete for loop to iterate over all indices + /// of a tensor, for example: + /// + /// for(size_t i = 0; i < iter.numel(); ++i) + /// { + /// const auto index = iter(i); + /// use(index); + /// } + __host__ __device__ size_t numel() const { return numel_; } + + private: + /// Reverse (right) scan of the shape to iterate over. + Extent shape_scan_; + + /// The total number of elements in the shape. This value turns out to be almost + /// always required when iterating over a shape, so just store it in this type + /// so that it is easily accessible. + size_t numel_; +}; + +template +NdIter(Extent) -> NdIter; + +/// @brief Concept for constraining tensor iteration functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `tensor_foreach` function. +template +concept ForeachFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::same_as; +}; + +namespace detail { + +/// @brief Default foreach kernel block size +/// +/// This value is the default number of threads in each block when +/// executing the foreach kernel. This value is mostly arbitrary, +/// 256 is usually a good default for AMD GPUs. +/// +/// @see tensor_foreach +constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256; + +/// @brief Tensor iteration kernel +/// +/// This kernel implements the actual iteration logic, and is intended +/// to be used solely by `tensor_foreach` to iterate & invoke the +/// actual callback. +/// +/// @tparam BLOCK_SIZE The number of threads in each block on the GPU. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to +/// iterate. +/// @tparam F The type of the callback to invoke. This function must be +/// compatible with execution as a __device__ function. +/// +/// @param iter An NdIter instance to help iterating over the tensor. +/// @param f The callback to invoke for each index of the tensor. 
This +/// functor must be eligible for running on the GPU. +template + requires ForeachFunctor +__global__ __launch_bounds__(BLOCK_SIZE) // + void foreach_kernel(NdIter iter, F f) +{ + const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x; + for(size_t flat_idx = gid; flat_idx < iter.numel(); flat_idx += gridDim.x * BLOCK_SIZE) + { + // Compute the current index. + const auto index = iter(flat_idx); + + // Then invoke the callback with the index. + f(index); + } +} + +/// @brief A utility to get a C++ type for a CKB type +/// +/// Right now this is just an alias of an internal CKB helper, +/// but this should probably be moved elsewhere. +template +using cpp_type_t = typename builder::factory::internal::DataTypeToCK
::type; + +} // namespace detail + +/// @brief Calculate tensor memory offset given index and strides. +/// +/// This function returns the offset in memory in a tensor, given a particular +/// multi-dimensional index and a particular set of strides. Each value in the +/// index corresponds one-to-one with a value in the strides, which are the +/// index and stride at that dimension in the tensor. These strides must be +/// pre-scanned, meaning that each index is the absolute stride of elements +/// along that axis. In essence, this means that you should pass the output of +/// `TensorDescriptor::get_strides()` into this function. +/// +/// @pre The index must be inside the tensor space. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param index A multi-dimensional index inside the tensor space. +/// @param strides A set of strides, one for each dimension. +/// +/// @see TensorDescriptor +template +__host__ __device__ size_t calculate_offset(const Extent& index, const Extent& strides) +{ + size_t offset = 0; +#pragma unroll + for(size_t i = 0; i < RANK; ++i) + { + offset += index[i] * strides[i]; + } + return offset; +} + +/// @brief Invoke a callback on the GPU for every index in a tensor. +/// +/// This function invokes a callback functor on the GPU, for each index in +/// a tensor. This function _only_ takes care of iterating over all indices +/// in a tensor of a particular shape; this function does not handle or know +/// about actual tensor data. +/// +/// @note This function is currently implemented relatively naively: The +/// iteration order is always row-wise, implemented as a persistent kernel. +/// The main objective of this function is to be used with the CK-Builder +/// testing system, and so readability and correctness should be preferred +/// over performance. If this is ever a source of performance problems, +/// feel free to replace the implementation with something better. 
+/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param shape The shape of the tensor to iterate over. +/// @param f The callback to invoke for each index of the tensor. This +/// functor must be eligible for running on the GPU. +/// +/// @see ForeachFunctor +/// @see detail::foreach_kernel +template +void tensor_foreach(const Extent& shape, ForeachFunctor auto f) +{ + constexpr int block_size = detail::DEVICE_FOREACH_BLOCK_SIZE; + const auto kernel = detail::foreach_kernel; + + int occupancy; + check_hip(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0)); + + int device; + check_hip(hipGetDevice(&device)); + + int multiprocessors; + check_hip( + hipDeviceGetAttribute(&multiprocessors, hipDeviceAttributeMultiprocessorCount, device)); + + // Pre-scan the shape to help indexing in the kernel. + // Note: the order is not that important, so long as the iteration + // order in the kernel is from large-to-small. Right layout is the + // easiest solution for that. + + NdIter iter(shape); + + // Reset any errors from previous launches. + (void)hipGetLastError(); + + kernel<<>>(iter, f); + check_hip(hipGetLastError()); +} + +/// @brief Concept for tensor initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor` function. +template +concept FillTensorFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensors. +/// +/// This function is a utility helper for initializing tensors. It accepts a +/// tensor descriptor, buffer, and a callback. The callback is invoked for every +/// coordinate (which is passed to the callback), and the tensor is initialized +/// with resulting value. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. 
+/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular coordinate. +/// +/// @see FillTensorFunctor +template +void fill_tensor(const TensorDescriptor& desc, + void* buffer, + FillTensorFunctor auto f) +{ + const auto strides = desc.get_strides(); + tensor_foreach(desc.get_lengths(), [buffer, f, strides](const auto& index) { + using T = detail::cpp_type_t
; + auto* ptr = static_cast(buffer); + const auto offset = calculate_offset(index, strides); + + ptr[offset] = f(index); + }); +} + +/// @brief Concept for tensor buffer initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor_buffer` function. +template +concept FillTensorBufferFunctor = requires(const F& f, size_t index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensor buffers. +/// +/// This function is a utility for initializing memory backing a tensor buffer. In +/// contrast to `fill_tensor`, this function first extracts the backing space of +/// the tensor, and then invokes the callback for each (flat) index. This function +/// is particularly useful for initializing out-of-bounds indices with a known +/// value. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular index. +/// +/// @see FillTensorBufferFunctor +template +void fill_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + FillTensorBufferFunctor
auto f) +{ + fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); }); +} + +/// @brief Utility for clearing tensor buffers to a particular value. +/// +/// This function initializes all memory backing a particular tensor buffer to +/// one specific value, zero by default. Note that this function ignores strides, +/// and clears the entire buffer backing the tensor. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param value The value to initialize the tensor buffer with. +template +void clear_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + detail::cpp_type_t
value = detail::cpp_type_t
{0}) +{ + fill_tensor_buffer(desc, buffer, [value]([[maybe_unused]] size_t i) { return value; }); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 15cb43f369..2976e6c14b 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -19,15 +19,30 @@ namespace ck_tile::builder::test { -template -void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, - int min_val, - int max_val) +/// @brief Initialize tensor data with a uniform int distribution +/// +/// This function initializes a tensor's device memory with random integer data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_int(void* buf, + const TensorDescriptor& descriptor, + int min_value, + int max_value) { size_t size = descriptor.get_element_space_size_in_bytes(); - if(max_val - min_val <= 1) + if(max_value - min_value <= 1) { throw std::runtime_error("Error while filling device tensor with random integer data: max " "value must be at least 2 greater than min value, otherwise " @@ -38,19 +53,34 @@ void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, // we might be asked to generate int values on fp data types that don't have the required // precision - if(static_cast(max_val - 1) == static_cast(min_val)) + if(static_cast(max_value - 1) == static_cast(min_value)) { throw std::runtime_error("Error while filling device tensor with random integer data: " "insufficient precision in specified range"); } size_t packed_size = ck::packed_size_v; fill_tensor_uniform_rand_int_values<<<256, 256>>>( - static_cast(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void 
init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a uniform float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_fp(void* buf, + const TensorDescriptor& descriptor, float min_value, float max_value) { @@ -59,15 +89,30 @@ void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; - fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf.get()), + fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a normal float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a normal distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param sigma The standard deviation of the distribution. +/// @param mean The mean of the distribution. +template +void init_tensor_buffer_normal_fp(void* buf, + const TensorDescriptor& descriptor, float sigma, float mean) { @@ -76,7 +121,7 @@ void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; fill_tensor_norm_rand_fp_values<<<256, 256>>>( - static_cast(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), sigma, mean, (size * packed_size) / sizeof(ck_type)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index a0dfa27409..eb16402bc2 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -5,6 +5,10 @@ #include +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/validation.hpp" + /// This file is the main header for the CK-Builder testing system. A high-level /// description of this testing system is documented in /// `ck_tile/builder/testing/README.md`. This file deals mainly deals with the @@ -78,7 +82,7 @@ namespace ck_tile::builder::test { /// that this structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Args; @@ -98,7 +102,7 @@ struct Args; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Inputs; @@ -118,7 +122,7 @@ struct Inputs; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. 
+/// @tparam SIGNATURE The signature to specialize the structure for. template struct Outputs; @@ -130,10 +134,10 @@ struct Outputs; /// be created using `alloc_inputs()` and that an instance of the corresponding /// `Inputs` structure can be obtained using `.get()`. /// -/// @note The easiest way to implement this type is to use the `DeviceBuffer` -/// type to allocate individual device buffers for each input tensor. +/// @note A default implementation is provided for this type if `Inputs` +/// supports `TensorReflectable`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_inputs() /// @see ValidUniqueInputs @@ -149,10 +153,10 @@ struct UniqueInputs; /// be created using `alloc_outputs()` and that an instance of the corresponding /// `Outputs` structure can be obtained using `.get()`. /// -/// @note The easiest way to implement this type is to use the `DeviceBuffer` -/// type to allocate individual device buffers for each output tensor. +/// @note A default implementation is provided for this type if `Outputs` +/// supports `TensorReflectable`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_outputs() /// @see ValidUniqueOutputs @@ -195,7 +199,15 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @note A default implementation is provided for this function if `Inputs` +/// supports `TensorReflectable`. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. 
+/// +/// @param args The run-time arguments of the operation. /// /// @see Inputs /// @see UniqueInputs @@ -203,21 +215,26 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// @see alloc_tensor_buffer() template requires ValidUniqueInputs -UniqueInputs alloc_inputs(const Args& args); +UniqueInputs alloc_inputs(const Args& args) = delete; -/// @brief Allocate inputs corresponding to a signature. +/// @brief Initialize inputs corresponding to a signature. /// /// The `init_inputs()` function is used to initialize pseudo-random data -/// to the tensors specified in the Inputs structure. +/// to the tensors specified in the Inputs structure. Implementors should +/// fill each of the tensors in `inputs` with appropriate random data. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. /// /// @tparam SIGNATURE the signature to specialize the structure for. /// +/// @param args The run-time arguments of the operation. +/// @param inputs The operation inputs to initialize with random data. +/// /// @see Inputs -/// @see UniqueInputs /// @see tensor_initialization template - requires ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs); +void init_inputs(const Args& args, Inputs inputs) = delete; /// @brief Allocate outputs corresponding to a signature. /// @@ -226,7 +243,15 @@ void init_inputs(const Args& args, UniqueInputs& inputs); /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @note A default implementation is provided for this function if `Outputs` +/// supports `TensorReflectable`. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. 
+/// +/// @param args The run-time arguments of the operation. /// /// @see Outputs /// @see UniqueOutputs @@ -234,7 +259,34 @@ void init_inputs(const Args& args, UniqueInputs& inputs); /// @see alloc_tensor_buffer() template requires ValidUniqueOutputs -UniqueInputs alloc_outputs(const Args& args); +UniqueInputs alloc_outputs(const Args& args) = delete; + +/// @brief Compare device operation outputs. +/// +/// This function implements the main comparison functionality, used to compare +/// the output of one implementation for a particular `SIGNATURE` with that of +/// another. Usually, the `expected` output should be computed by a reference +/// implementation. +/// +/// The implementation of this function generates a "report", which includes +/// detailed information about which tensors are different, how many elements +/// were incorrect, and where (a subset of) those elements are located within +/// the tensor. See `ValidationReport` for more information about the report. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. +/// @param actual The actual results, the results of the operation to-be-tested. +/// @param expected The expected results, the results of the reference implementation. +/// +/// @see ValidationReport +template +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) = delete; /// @brief Invoke a device operation created by CK Builder. /// @@ -257,7 +309,7 @@ UniqueInputs alloc_outputs(const Args& args); /// @post The tensors in `outputs` are overwritten with the outputs of the device /// operation. /// -/// @tparam SIGNATURE the signature to specialize this function for +/// @tparam SIGNATURE The signature to specialize this function for /// @tparam Operation the kernel of the operation to invoke. 
This type should be /// one that is created using the Builder API. /// @param operation An instance of the operation to invoke. @@ -265,10 +317,13 @@ UniqueInputs alloc_outputs(const Args& args); /// @param inputs The input tensor data. Will not be modified by this function. /// @param outputs The output tensor data. The contents will be overwritten by /// this function. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. template void run(Operation& operation, const Args& args, const Inputs& inputs, - const Outputs& outputs); + const Outputs& outputs) = delete; } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp new file mode 100644 index 0000000000..81d5b7a6f5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp @@ -0,0 +1,199 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +/// testing.hpp requires developers of a type of SIGNATURE to implement +/// quite a lot of functionality for each SIGNATURE. For example, next +/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define +/// `UniqueInputs`, `UniqueOutputs`, `alloc_inputs`, `alloc_outputs`, +/// and `validate`. The implementation of these latter few functions +/// is usually quite straight forward and adds a bunch of copy-paste +/// overhead. The functionality in this file offers an alternative +/// route: By implementing some reflection functionality in `Inputs` +/// and `Outputs`, we can automatically derive most of the functionality. + +namespace ck_tile::builder::test { + +/// @brief Check whether an `Input` or `Output` struct can be reflected. 
+/// +/// In order to avoid having to manually redefine a bunch of types related to +/// each `Inputs`/`Outputs` structure, those structures can also provide some +/// "reflection" functionality. To this end, they should implement +/// `static void reflect(const Args& args, auto inspect)`, where `inspect` +/// is called with information about each field in the struct. In more detail, +/// the signature of the `inspect` function is as follows: +/// +/// void inspect( +/// // A human-readable name for the tensor +/// std::string_view name, +/// // Descriptor for the tensor in memory, usually obtained via `args`. +/// const TensorDescriptor& desc, +/// // Member pointer to a field of `T`, which is a GPU-memory pointer +/// // to the relevant tensor memory. +/// void* T::* ptr); +/// +/// Here, `T` is `Inputs` or `Outputs`. +/// +/// @see Inputs +/// @see Outputs +template +concept TensorReflectable = requires(const Args& args) { + { + T::reflect(args, + []([[maybe_unused]] std::string_view name, + // Note: This will be a TensorDescriptor, but the actual + // DT and RANK may differ depending on member. + [[maybe_unused]] const auto& desc, + [[maybe_unused]] void* T::*ptr) {}) + }; +}; + +namespace detail { + +/// The default alignment between tensors allocated separately +/// by `UniqueTensors`. This should be large enough to accommodate +/// any type. hipMalloc returns an alignment of 256 by default. +constexpr size_t TENSOR_ALIGNMENT = 256; + +/// @brief Common type for automatically managing memory of sets of tensors. +/// +/// This type implements the automatic memory management logic for `Inputs` and +/// `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. +/// @tparam Tensors The `Inputs` or `Outputs` type corresponding to `SIGNATURE`. +template + requires TensorReflectable +struct UniqueTensors +{ + /// @brief Allocate tensors.
+ /// + /// This function computes the total size of memory to allocate according to + /// the tensors in `args`, and then allocates it as a continuous buffer. + /// + /// @param args The run-time arguments of the operation. + explicit UniqueTensors(const Args& args) + { + // First compute the total size of all tensors combined + size_t total_size = 0; + Tensors::reflect(args, + [&, this]([[maybe_unused]] std::string_view name, + const auto& desc, + [[maybe_unused]] void* Tensors::*ptr) { + total_size = align_fwd(total_size, TENSOR_ALIGNMENT); + total_size += desc.get_element_space_size_in_bytes(); + }); + + data_ = alloc_buffer(total_size); + + // Now assign the pointers based on the same offsets that + // we computed in the first loop. + size_t offset = 0; + Tensors::reflect(args, + [&, this]([[maybe_unused]] std::string_view name, + const auto& desc, + [[maybe_unused]] void* Tensors::*ptr) { + offset = align_fwd(offset, TENSOR_ALIGNMENT); + tensors_.*ptr = data_.get() + offset; + offset += desc.get_element_space_size_in_bytes(); + }); + } + + /// @brief Return raw `Inputs` or `Outputs` type. + /// + /// @see ValidUniqueInputs + /// @see ValidUniqueOutputs + Tensors get() const { return tensors_; } + + private: + /// Owning pointer of input memory + DeviceBuffer data_; + /// Struct with pointers to each tensor. Stored here so that we + /// don't need to keep recomputing it. + Tensors tensors_; +}; + +} // namespace detail + +/// @brief Implementation of `UniqueInputs` for `Inputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @see UniqueInputs +template + requires TensorReflectable, SIGNATURE> +struct UniqueInputs : detail::UniqueTensors> +{ + using detail::UniqueTensors>::UniqueTensors; +}; + +/// @brief Implementation of `UniqueOutputs` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. 
+/// +/// @see UniqueOutputs +template + requires TensorReflectable, SIGNATURE> +struct UniqueOutputs : detail::UniqueTensors> +{ + using detail::UniqueTensors>::UniqueTensors; +}; + +/// @brief Implementation of `alloc_inputs` for `Inputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// +/// @see alloc_inputs +template + requires TensorReflectable, SIGNATURE> +UniqueInputs alloc_inputs(const Args& args) +{ + static_assert(ValidUniqueInputs, "sanity check"); + return UniqueInputs(args); +} + +/// @brief Implementation of `alloc_outputs` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// +/// @see alloc_outputs +template + requires TensorReflectable, SIGNATURE> +UniqueOutputs alloc_outputs(const Args& args) +{ + static_assert(ValidUniqueOutputs, "sanity check"); + return UniqueOutputs(args); +} + +/// @brief Implementation of `validate` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// @param actual The actual results, the results of the operation to-be-tested. +/// @param expected The expected results, the results of the reference implementation. 
+/// +/// @see alloc_outputs +template + requires TensorReflectable, SIGNATURE> +ValidationReport +validate(const Args& args, Outputs actual, Outputs expected) +{ + ValidationReport report; + + Outputs::reflect( + args, [&](std::string_view name, const auto& desc, void* Outputs::*ptr) { + report.check(name, desc, actual.*ptr, expected.*ptr); + }); + + return report; +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp index 8db0e5d25d..4026642bd0 100644 --- a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp @@ -39,7 +39,7 @@ constexpr size_t data_type_sizeof(DataType data_type) case DataType::FP8: return 1; case DataType::BF8: return 1; case DataType::FP64: return 8; - case DataType::INT32: return 4; + case DataType::I32: return 4; case DataType::I8: return 1; case DataType::I8_I8: return 2; case DataType::U8: return 1; diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp new file mode 100644 index 0000000000..158f271e21 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -0,0 +1,204 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck/utility/type_convert.hpp" +#include +#include +#include +#include +#include + +/// This file implements functionality related to "validation", ie, functionality +/// to compare tensors. 
The functionality in this file should be testing-framework +/// agnostic, and it should NOT generate any error messages by itself. Instead, +/// all relevant information should be stored in the `ValidationReport` structure. +/// This structure should then be used to generate error messages, explanations, +/// etc, by the actual testing framework that the user has chosen. + +namespace ck_tile::builder::test { + +/// @brief Information about how a set of comparisons failed or succeeded. +/// +/// This structure represents a "report" generated by comparing sets of tensors. +/// It is intended to be used as the result of `ckt::validate()`, where `check()` +/// is invoked for each of the output tensors of a particular device operation. +/// The test should be considered successful if _all_ of those checks pass, +/// which can be inspected by asserting that `get_errors().size()` is 0. +struct ValidationReport +{ + /// @brief Information related to a single tensor comparison. + /// + /// This structure holds the information about the result of comparing + /// two particular tensors. + struct Case + { + /// The name of the tensor that was compared here, stored here for convenience + /// so that reporting any errors is easier. + std::string tensor_name; + + /// The number of elements which were different between the two compared tensors. + uint64_t wrong_elements; + + /// The total number of elements in each tensor. + uint64_t total_elements; + + /// The number of elements which were bitwise 0. + uint64_t zero_elements; + + /// @brief Check whether the output and reference tensors were both all zeros. + /// + /// If both tensors are all zero, it indicates either an incorrect testing setup + /// or an issue with the testing framework. For that reason we also consider that + /// a failure. + bool is_all_zero() const { return zero_elements == total_elements; } + + /// @brief Return whether the check associated to this case was successful. 
+ /// + /// This function returns whether the check associated to this case was successful, + /// which is directly derived from checking whether the number of incorrect elements + /// was 0 AND whether the tensor was not all zero. + bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); } + }; + + /// @brief Get comparison cases which were incorrect. + /// + /// This function returns a vector of comparison cases that did not succeed, ie, for + /// which `Case::is_ok` returns false. In order to check whether validation passed, it + /// is sufficient to assert that this function returns no cases. + std::vector get_errors() const + { + std::vector errors; + std::copy_if(reports_.begin(), + reports_.end(), + std::back_inserter(errors), + [](const auto& report) { return !report.is_ok(); }); + return errors; + } + + /// @brief Compare two tensors and record the results in the report. + /// + /// This is the main function used to compare two tensors. The results of this + /// comparison, including any supplemental information, are recorded into the report. + /// + /// @returns `false` if the comparison failed. If so, the details can be found via + /// `get_errors()`. + /// + /// @tparam DT The data type of the tensors to check. + /// @tparam RANK The rank (number of spatial dimensions) of the tensor to check. + /// + /// @param tensor_name The name of the tensors to check. This should be a value by which + /// whoever is debugging the associated test later can easily find out which of the + /// outputs of a device operation was incorrect. + /// @param descriptor The descriptor (memory layout) of the tensor. + /// @param actual The device buffer with the values of the tensor to-be-tested, ie, the + /// results of the device operation. + /// @param expected The device buffer with the values of the reference tensor. These are + /// treated as a "gold standard", and should usually be generated by a reference + /// implementation. 
+ /// @param rtol The relative acceptable tolerance between two values. + /// @param atol The absolute acceptable tolerance between two values. + template + bool check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual, + const void* expected, + double rtol = 1e-3, + double atol = 1e-3); + + private: + std::vector reports_; +}; + +template +bool ValidationReport::check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual_data, + const void* expected_data, + double rtol, + double atol) +{ + const auto strides = descriptor.get_strides(); + + // During development and CI, only the kernels that were changed would fail, and so we can + // assume that the average case does not have errors. Therefore, split out testing into a + // quick test which just counts the incorrect elements, and a more in-depth test that also + // returns the indices of the incorrect items. + + // Initial pass: count errors + + // Allocate and reset counter + auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); + + auto d_error_count = &reinterpret_cast(d_counters.get())[0]; + auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + + tensor_foreach(descriptor.get_lengths(), [=](auto index) { + using CKType = typename factory::internal::DataTypeToCK
::type; + + const auto* actual = static_cast(actual_data); + const auto* expected = static_cast(expected_data); + + static_assert(!std::is_same_v, + "TODO implement compare_kernel() for double"); + + const auto offset = calculate_offset(index, strides); + + const auto a = actual[offset]; + const auto b = expected[offset]; + + const auto o = static_cast(type_convert(a)); + const auto r = static_cast(type_convert(b)); + const auto err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + // We expect the number of errors to be very low, so just use an atomic + // for now. + atomicAdd(d_error_count, 1); + } + + // Now compare the numbers as bitwise too. + // Update the counter if they're both zero. + using Bytes = std::array; + bool all_zero = true; + for(auto x : std::bit_cast(a)) + { + if(x != std::byte{0}) + all_zero = false; + } + for(auto x : std::bit_cast(b)) + { + if(x != std::byte{0}) + all_zero = false; + } + if(all_zero) + { + atomicAdd(d_zero_count, 1); + } + }); + + uint64_t error_count = 0; + check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + uint64_t zero_count = 0; + check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + + // TODO: Gather detailed coordinates. 
+ + reports_.push_back(Case{ + .tensor_name = std::string(tensor_name), + .wrong_elements = error_count, + .total_elements = descriptor.get_element_size(), + .zero_elements = zero_count, + }); + + return reports_.back().is_ok(); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c1c62e91fa..c4cca05e52 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -24,7 +24,7 @@ enum class DataType FP8, BF8, FP64, - INT32, + I32, I8, I8_I8, U8 @@ -192,8 +192,8 @@ enum class TileConvSpecialization FILTER_3x3 }; -// Enums for the forward convolution specialization. -enum class ConvFwdSpecialization +// Enums for the convolution specializations. +enum class ConvSpecialization { DEFAULT, FILTER_1X1_PAD0, @@ -202,22 +202,6 @@ enum class ConvFwdSpecialization ODD_C }; -// Enums for the backward data convolution specialization. -enum class ConvBwdDataSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, -}; - -// Enums for the backward weight convolution specialization. -enum class ConvBwdWeightSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, - FILTER_1X1_PAD0, - ODD_C, -}; - // Enums for the Gemm padding. 
enum class GemmPadding { @@ -249,11 +233,13 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { LARGE_TENSOR, - REFERENCE // GPU reference implementation for validation + REFERENCE, // GPU reference implementation for validation, + TWO_STAGE, + MULTIPLE_D }; -// toString methods for enum classes -inline std::string_view toString(DataType dt) +// to_string methods for enum classes +inline std::string_view to_string(DataType dt) { using enum DataType; switch(dt) @@ -267,7 +253,7 @@ inline std::string_view toString(DataType dt) case FP8: return "FP8"; case BF8: return "BF8"; case FP64: return "FP64"; - case INT32: return "INT32"; + case I32: return "I32"; case I8: return "I8"; case I8_I8: return "I8_I8"; case U8: return "U8"; @@ -276,7 +262,7 @@ inline std::string_view toString(DataType dt) } } -inline std::string_view toString(ConvDirection dir) +inline std::string_view to_string(ConvDirection dir) { using enum ConvDirection; switch(dir) @@ -288,7 +274,7 @@ inline std::string_view toString(ConvDirection dir) } } -inline std::string_view toString(ElementwiseOperation op) +inline std::string_view to_string(ElementwiseOperation op) { using enum ElementwiseOperation; switch(op) @@ -332,7 +318,7 @@ inline std::string_view toString(ElementwiseOperation op) } } -inline std::string_view toString(PipelineVersion ver) +inline std::string_view to_string(PipelineVersion ver) { using enum PipelineVersion; switch(ver) @@ -347,7 +333,7 @@ inline std::string_view toString(PipelineVersion ver) } } -inline std::string_view toString(GemmSpecialization spec) +inline std::string_view to_string(GemmSpecialization spec) { using enum GemmSpecialization; switch(spec) @@ -372,9 +358,9 @@ inline std::string_view toString(GemmSpecialization spec) } } -inline std::string_view toString(ConvFwdSpecialization spec) +inline std::string_view to_string(ConvSpecialization spec) { - using enum ConvFwdSpecialization; + using enum ConvSpecialization; switch(spec) { case DEFAULT: 
return "DEFAULT"; @@ -386,31 +372,7 @@ inline std::string_view toString(ConvFwdSpecialization spec) } } -inline std::string_view toString(ConvBwdDataSpecialization spec) -{ - using enum ConvBwdDataSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - default: return "Unknown"; - } -} - -inline std::string_view toString(ConvBwdWeightSpecialization spec) -{ - using enum ConvBwdWeightSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0"; - case ODD_C: return "ODD_C"; - default: return "Unknown"; - } -} - -inline std::string_view toString(GemmPadding padding) +inline std::string_view to_string(GemmPadding padding) { using enum GemmPadding; switch(padding) @@ -435,7 +397,7 @@ inline std::string_view toString(GemmPadding padding) } } -inline std::string_view toString(PipelineScheduler sched) +inline std::string_view to_string(PipelineScheduler sched) { using enum PipelineScheduler; switch(sched) @@ -447,7 +409,7 @@ inline std::string_view toString(PipelineScheduler sched) } } -inline std::string_view toString(TensorLayout layout) +inline std::string_view to_string(TensorLayout layout) { using enum TensorLayout; switch(layout) @@ -503,63 +465,46 @@ inline std::string_view toString(TensorLayout layout) } // ostream operator overloads for enum classes -inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); } +inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << to_string(dt); } -inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); } +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + return os << to_string(dir); +} inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) { - return os << toString(op); + return os << 
to_string(op); } inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) { - return os << toString(ver); + return os << to_string(ver); } inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) { - return os << toString(spec); + return os << to_string(spec); } -inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec) { - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) -{ - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) -{ - return os << toString(spec); + return os << to_string(spec); } inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) { - return os << toString(padding); + return os << to_string(padding); } inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) { - return os << toString(sched); + return os << to_string(sched); } inline std::ostream& operator<<(std::ostream& os, TensorLayout layout) { - return os << toString(layout); -} - -// ostream operator overload for std::variant of convolution specializations -inline std::ostream& operator<<(std::ostream& os, - const std::variant& spec) -{ - std::visit([&os](const auto& s) { os << s; }, spec); - return os; + return os << to_string(layout); } } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 800d485660..9890563859 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -80,33 +80,41 @@ add_ck_builder_test(test_ckb_conv_builder test_instance_traits_util.cpp unit_device_buffer.cpp unit_tensor_descriptor.cpp + unit_tensor_foreach.cpp + unit_error.cpp + unit_validation.cpp + unit_debug.cpp + unit_conv_fwd_testing.cpp unit_conv_elementwise_op.cpp unit_conv_tensor_layout.cpp 
unit_conv_tensor_type.cpp unit_conv_thread_block.cpp unit_conv_tuning_params.cpp) - - # Tests the inline diff utility used for comparing strings in tests assertions - add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) +target_link_libraries(test_ckb_conv_builder PRIVATE utility) - # GPU reference validation tests (in validation/ folder) - # 1. Reference kernel execution and InstanceTraits - add_ck_builder_test(test_ckb_reference_execution - validation/test_reference_execution.cpp - validation/test_reference_instance_traits.cpp) - target_link_libraries(test_ckb_reference_execution PRIVATE utility) - - # Note: Optimized kernel validation tests will be added after merging dev branch - # with kernel Run() implementation from colleague's work +# Tests the inline diff utility used for comparing strings in tests assertions +add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) + +# GPU reference validation tests (in validation/ folder) +# 1. Reference kernel execution and InstanceTraits +add_ck_builder_test(test_ckb_reference_execution + validation/test_reference_execution.cpp + validation/test_reference_instance_traits.cpp) +target_link_libraries(test_ckb_reference_execution PRIVATE utility) + +# Note: Optimized kernel validation tests will be added after merging dev branch +# with kernel Run() implementation from colleague's work + +# Tests convolution trait selection and configuration +add_ck_builder_test(test_ckb_conv_traits + conv/ck/test_conv_traits.cpp + conv/ck/unit_instance_to_conv_traits_features.cpp + conv/ck/unit_instance_to_conv_traits_instances.cpp) + +# Tests convolution problem description and parameter handling +add_ck_builder_test(test_ckb_conv_description + test_conv_description.cpp) - # Tests convolution trait selection and configuration - add_ck_builder_test(test_ckb_conv_traits - conv/ck/test_conv_traits.cpp) - - # Tests convolution problem description and parameter handling - add_ck_builder_test(test_ckb_conv_description - 
test_conv_description.cpp) - ################################################################################ # REGRESSION TESTS - Integration Tests (With Kernel Compilation) ################################################################################ @@ -117,7 +125,7 @@ add_ck_builder_test(test_ckb_conv_builder # Verifies that GetInstanceString() methods and other functions produce valid kernel code. # Tests various convolution types: # - Group convolution (v3, standard, large tensor, WMMA, DL variants) -# - Backward weight group convolution (XDL) +# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants) # Requires kernel compilation to validate the generated strings through the base class. set(INSTANCE_STRING_TESTS @@ -160,10 +168,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp) + ) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) +set(BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_dl.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +) + +if (CK_USE_WMMA) + list(APPEND BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp + ) +endif() + +add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS}) +target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) + 
+add_ck_builder_test(test_ckb_build_bwd_data_instances + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp + ) +target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility) + ################################################################################ # FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) @@ -217,6 +250,8 @@ endforeach() set(CKB_REGRESSION_TESTS test_ckb_instance_string test_ckb_build_fwd_instances + test_ckb_build_bwd_weight_instances + test_ckb_build_bwd_data_instances test_ckb_testing_utils # test_ckb_factory_grouped_convolution_forward_convscale # test_ckb_factory_grouped_convolution_forward_scaleadd_ab diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp new file mode 100644 index 0000000000..584bce2f1b --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{} + .with_thread_block(cku::ThreadBlock_256_128x128x16) + .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) + .with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1) + .with_dl_thread_cluster(cku::DlThreadCluster_8x2) + .with_dl_transfer(cku::DlTransfer5D); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_DL, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Dl", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..404d1dbacd --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNDHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "GNDHWC,GKZYXC,GNDHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp new file mode 100644 index 0000000000..206fc8beb9 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp @@ -0,0 +1,41 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..782f33f845 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -0,0 +1,46 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "NGCHW,GKYXC,NGKHW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v1", + "fp16,fp16,2,2>"}); // Check compute types and transpose params. 
+} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp new file mode 100644 index 0000000000..a2a877dbcd --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "Intrawave,v2", // pipeline 
versions + "bf16,bf16,2,4>"}); // compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp new file mode 100644 index 0000000000..ff350ac804 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NGKDHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffle", + expected_transfer_parameters, 
+ "Default", + "NGCDHW,GKZYXC,NGKDHW", + "PassThrough,PassThrough,PassThrough", + "v1"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..60f7d5bd64 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_transpose_params(4, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3", + 
expected_transfer_parameters, + "Filter1x1Stride1Pad0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v1", + "bf16,bf16,4,4>"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp new file mode 100644 index 0000000000..892f1d35ef --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + 
"fp16,fp16,2,2>"}); // check compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp new file mode 100644 index 0000000000..4ad97209e5 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", + expected_transfer_parameters, + "Filter1x1Stride1Pad0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v2"}); +} diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 284b3929ee..8d85370b26 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -30,11 +30,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index 6802e0caf8..d3ace110c4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -27,11 +27,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2); using Builder = ConvBuilder; diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 14463bbc17..06d200429c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -22,18 +22,20 @@ TEST(FwdConvInstances, constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, .direction = FORWARD, .data_type = I8, - .accumulation_data_type = INT32, + .accumulation_data_type = I32, .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} - .with_thread_block(FwdThreadBlock_128_64x64x64) - .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) - .with_transfer(FwdTransfer_4x32x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); + .with_thread_block(ThreadBlock_128_64x64x64) + .with_gemm_config(GemmParams_Wmma_2x1_per_wave) + .with_transfer(Transfer_4x32x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2) + .with_gridwise_gemm_pipeline(PipelineVersion::V1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 4a5618a6b1..610e2fad5f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) 
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -64,10 +64,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_3x3, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 0d9563e05a..23edef5436 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -32,11 +32,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, 
PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 9bea834ef9..58171cd530 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Default", @@ -59,16 +60,17 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; 
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Filter1x1Pad0", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index aa53aa9666..3e5e39191e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -5,12 +5,16 @@ #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" #include "ck_tile/builder/testing/conv_fwd_ck.hpp" +#include "ck_tile/builder/testing/conv_fwd_reference.hpp" #include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -21,16 +25,18 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(cku::FwdThreadBlock_256_256x256x32) + .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(cku::FwdTransfer_4x64x1) - .with_specializations(ckb::ConvFwdSpecialization::DEFAULT, - ckb::GemmSpecialization::MNKPadding) + .with_transfer(cku::Transfer_4x64x1) + .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, + ckb::GemmSpecialization::MNKPadding) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(Fwd2DFp16_CShufV3_GNHWC, 
Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); @@ -78,11 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) .cde_elementwise_op = {}, }; - auto inputs = alloc_inputs(args); - auto outputs = alloc_outputs(args); + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); - init_inputs(args, inputs); + ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); + + auto ref_conv = Reference{}; + ckt::run(ref_conv, args, inputs.get(), reference.get()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 79ee4915e8..bb35c53ba0 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -26,11 +26,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 3e3d7e8c2b..b117e693fe 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -27,11 +27,12 @@ 
TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) - .with_transfer(FwdTransfer_4x64x1_fp8) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x64x1_fp8) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 3019c57a18..97bc0a00e5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -25,14 +25,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} + .with_thread_block(ThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + 
.with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -62,14 +61,14 @@ TEST( .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_128_128x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} + .with_thread_block(ThreadBlock_128_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 3f9bdfb972..9e6ca00e58 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git 
a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index b30f958bc4..56d4b8be59 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -27,11 +27,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 33c01c8ac4..df8339241b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -27,11 +27,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp 
b/experimental/builder/test/conv/ck/test_conv_traits.cpp index d5661ad67b..42235df2fe 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include @@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) ck::half_t, // BComputeDataType false>; // DirectLoad - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + 
EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); // Verify A tile transfer info - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); // 
Verify B tile transfer info - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); // Verify warp GEMM params - EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); - EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); - EXPECT_EQ(Traits::warp_gemm.m_iter, 4); - EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + 
EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); // Verify output tile transfer info - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); - EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); - EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); - EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1); + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle @@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) ck::LoopScheduler::Default, // LoopSched 1>; // NumGroupsToMerge - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, 
ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } // Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) @@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) ck::half_t, // BComputeDataType ck::LoopScheduler::Default>; // LoopSched - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, 
TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } } // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp new file mode 100644 index 0000000000..72269c38ac --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp @@ -0,0 +1,800 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Individual Conversion Functions +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify individual conversion and extraction functions that +// transform raw CK kernel parameters into semantic types. Each test +// focuses on a single conversion function to ensure it correctly maps +// CK types to builder enums and structures. +// +// TEST COVERAGE: +// -------------- +// 1. Enum Conversions: +// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion) +// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler) +// +// 2. Elementwise Operations (14 operations): +// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd +// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd +// +// 3. Convolution Properties: +// - Direction detection (Forward) +// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0, +// Filter3x3, OddC) +// +// 4. Layout Detection: +// - 1D layouts (GNWC, NWGC, NGCW) +// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX) +// - 3D layouts (GNDHWC, NDHWGC, NGCDHW) +// +// 5. Data Type Detection: +// - FP16, BF16, FP32, I8 +// +// 6. Pipeline Configuration: +// - Pipeline versions (V2, V3) +// - Schedulers (Interwave) +// +// 7. 
GEMM Padding Variations (16 types): +// - Default, MNK, M, N, K, MN, MK, NK +// - O, MO, NO, KO, MNO, MKO, NKO, MNKO +// ============================================================================ + +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/types.hpp" +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; +using ::ck_tile::builder::TensorLayout; +using ::testing::ElementsAre; + +// ============================================================================ +// Test Helper Templates +// ============================================================================ +// These templates reduce boilerplate by providing sensible defaults for +// template parameters that don't vary in most tests.
+// ============================================================================ + +namespace defaults { +// Default values used across most tests +static constexpr int kBlockSize = 256; +static constexpr int kMPerBlock = 128; +static constexpr int kNPerBlock = 128; +static constexpr int kKPerBlock = 16; +static constexpr int kAK1 = 8; +static constexpr int kBK1 = 8; +static constexpr int kMPerXDL = 32; +static constexpr int kNPerXDL = 32; +static constexpr int kMXdlPerWave = 4; +static constexpr int kNXdlPerWave = 4; +static constexpr int kABlockTransferSrcVectorDim = 2; +static constexpr int kABlockTransferSrcScalarPerVector = 8; +static constexpr int kABlockTransferDstScalarPerVector_AK1 = 8; +static constexpr int kABlockLdsExtraM = 1; +static constexpr int kBBlockTransferSrcVectorDim = 2; +static constexpr int kBBlockTransferSrcScalarPerVector = 8; +static constexpr int kBBlockTransferDstScalarPerVector_BK1 = 8; +static constexpr int kBBlockLdsExtraN = 1; +static constexpr int kCShuffleMXdlPerWavePerShuffle = 1; +static constexpr int kCShuffleNXdlPerWavePerShuffle = 1; +static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8; +static constexpr bool kDirectLoad = false; + +using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultABlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultCDEBlockTransferClusterLengths = ck::Sequence<1, 32, 1, 8>; +} // namespace defaults + +// DeviceInstanceForTests - V3 variant with sensible defaults +template +using DeviceInstanceForTests_V3 = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + NDimSpatial, + ALayout, + BLayout, + ck::Tuple<>, + 
ELayout, + ADataType, + BDataType, + AccDataType, + ADataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + defaults::kBlockSize, + defaults::kMPerBlock, + defaults::kNPerBlock, + defaults::kKPerBlock, + defaults::kAK1, + defaults::kBK1, + defaults::kMPerXDL, + defaults::kNPerXDL, + defaults::kMXdlPerWave, + defaults::kNXdlPerWave, + defaults::DefaultABlockTransferThreadClusterLengths, + defaults::DefaultABlockTransferThreadClusterArrangeOrder, + defaults::DefaultABlockTransferSrcAccessOrder, + defaults::kABlockTransferSrcVectorDim, + defaults::kABlockTransferSrcScalarPerVector, + defaults::kABlockTransferDstScalarPerVector_AK1, + defaults::kABlockLdsExtraM, + defaults::DefaultBBlockTransferThreadClusterLengths, + defaults::DefaultBBlockTransferThreadClusterArrangeOrder, + defaults::DefaultBBlockTransferSrcAccessOrder, + defaults::kBBlockTransferSrcVectorDim, + defaults::kBBlockTransferSrcScalarPerVector, + defaults::kBBlockTransferDstScalarPerVector_BK1, + defaults::kBBlockLdsExtraN, + defaults::kCShuffleMXdlPerWavePerShuffle, + defaults::kCShuffleNXdlPerWavePerShuffle, + defaults::DefaultCDEBlockTransferClusterLengths, + defaults::kCDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ADataType, + BDataType, + defaults::kDirectLoad>; + +// Test case helper for specialization testing +template +using SpecializationTestInstance = + DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Spec>; + +// Test case helper for layout testing (1D, 2D, 3D) +template +using LayoutTestInstance = DeviceInstanceForTests_V3; + +// Test case helper 
for data type testing +template +using DataTypeTestInstance = DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + DataType, + DataType, + DataType, + AccDataType>; + +// Test case helper for pipeline version testing +template +using PipelineVersionTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + ck::BlockGemmPipelineScheduler::Intrawave, + PipelineVer>; + +// Test case helper for pipeline scheduler testing +template +using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + Scheduler>; + +// Test case helper for GEMM padding testing +template +using GemmPaddingTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + 
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + GemmSpec>; + +// ============================================================================ +// Test Enum Conversion Functions +// ============================================================================ + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ::ck::BlockGemmPipelineVersion; + using enum ::ck_tile::builder::PipelineVersion; + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V3); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), V5); +} + +TEST(InstanceToConvTraits, ConvertsPipelineVersion) +{ + using ck_tile::reflect::conv::convert_pipeline_version; + using enum ck::PipelineVersion; + using enum PipelineVersion; + EXPECT_EQ(convert_pipeline_version(), V1); + EXPECT_EQ(convert_pipeline_version(), V2); + EXPECT_EQ(convert_pipeline_version(), V4); + EXPECT_EQ(convert_pipeline_version(), WEIGHT_ONLY); +} + +TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::BlockGemmPipelineScheduler; + using enum PipelineScheduler; + EXPECT_EQ(convert_pipeline_scheduler(), INTRAWAVE); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +TEST(InstanceToConvTraits, ConvertsLoopScheduler) +{ + using ck_tile::reflect::conv::convert_pipeline_scheduler; + using enum ck::LoopScheduler; + using enum PipelineScheduler; + EXPECT_EQ(convert_pipeline_scheduler(), DEFAULT); + EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); +} + +// ============================================================================ +// Test Elementwise Operations +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsPassThroughOperation) +{ + using enum 
ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, PASS_THROUGH); +} + +TEST(InstanceToConvTraits, ExtractsScaleOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SCALE); +} + +TEST(InstanceToConvTraits, ExtractsReluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, RELU); +} + +TEST(InstanceToConvTraits, ExtractsGeluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, GELU); +} + +TEST(InstanceToConvTraits, ExtractsSigmoidOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SIGMOID); +} + +TEST(InstanceToConvTraits, ExtractsTanhOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, TANH); +} + +TEST(InstanceToConvTraits, ExtractsScaleAddOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SCALE_ADD); +} + +TEST(InstanceToConvTraits, ExtractsSiluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SILU); +} + +TEST(InstanceToConvTraits, ExtractsSwishOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, SWISH); +} + +TEST(InstanceToConvTraits, ExtractsEluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, ELU); +} + +TEST(InstanceToConvTraits, ExtractsLeakyReluOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, LEAKY_RELU); +} + 
+TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, UNARY_CONVERT); +} + +TEST(InstanceToConvTraits, ExtractsConvScaleOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, CONV_SCALE); +} + +TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation) +{ + using enum ElementwiseOperation; + constexpr auto op = + ck_tile::reflect::conv::elementwise_op(); + EXPECT_EQ(op, CONV_SCALE_ADD); +} + +// ============================================================================ +// Test Convolution Direction Detection +// ============================================================================ + +TEST(InstanceToConvTraits, DetectsForwardDirection) +{ + using DeviceInstance = DeviceInstanceForTests_V3<>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); +} + +// ============================================================================ +// Test Convolution Specialization Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); +} + 
+TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, + ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0); +} + +TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3); +} + +TEST(InstanceToConvTraits, ExtractsOddCSpecialization) +{ + using DeviceInstance = SpecializationTestInstance< + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C); +} + +// ============================================================================ +// Test 1D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGnwcLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::GNWC, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::GNWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK)); +} + +TEST(InstanceToConvTraits, ExtractsNwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::NWGC, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::NWGK>; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgcwLayout) +{ + using DeviceInstance = LayoutTestInstance<1, + ck::tensor_layout::convolution::NGCW, + ck::tensor_layout::convolution::GKXC, + ck::tensor_layout::convolution::NGKW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 1); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW)); +} + +// ============================================================================ +// Test 2D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGnhwcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); +} + +TEST(InstanceToConvTraits, ExtractsNhwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NHWGC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::NHWGK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::NGKHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW)); +} + +TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout) +{ + using DeviceInstance = LayoutTestInstance<2, + ck::tensor_layout::convolution::NGCHW, + ck::tensor_layout::convolution::GKCYX, + ck::tensor_layout::convolution::NGKHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW)); +} + +// ============================================================================ +// Test 3D Convolution Layout Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsGndhwcLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK)); +} + +TEST(InstanceToConvTraits, ExtractsNdhwgcLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::NDHWGC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::NDHWGK>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK)); +} + +TEST(InstanceToConvTraits, ExtractsNgcdhwLayout) +{ + using DeviceInstance = LayoutTestInstance<3, + ck::tensor_layout::convolution::NGCDHW, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::NGKDHW>; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW)); +} + +// ============================================================================ +// Test Data Type Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsFp16DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::FP16); +} + +TEST(InstanceToConvTraits, ExtractsBf16DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::BF16); +} + +TEST(InstanceToConvTraits, ExtractsFp32DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::FP32); +} + +TEST(InstanceToConvTraits, ExtractsI8DataType) +{ + using DeviceInstance = DataTypeTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.data_type, DataType::I8); +} + +// ============================================================================ +// Test Pipeline Version Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsPipelineV2) +{ + using DeviceInstance = PipelineVersionTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V2); +} + +TEST(InstanceToConvTraits, ExtractsPipelineV3) +{ + using DeviceInstance = PipelineVersionTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V3); +} + +TEST(InstanceToConvTraits, ExtractsInterwaveScheduler) +{ + using DeviceInstance = PipelineSchedulerTestInstance; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTERWAVE); +} + +// ============================================================================ +// Test GEMM Padding Detection +// ============================================================================ + +TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); +} + +TEST(InstanceToConvTraits, ExtractsMnkGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::M_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::N_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::K_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MN_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + 
EXPECT_EQ(traits.gemm_padding, GemmPadding::MK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NK_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::O_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::KO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = 
ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNKO_PADDING); +} + +} // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp new file mode 100644 index 0000000000..38942f9d45 --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp @@ -0,0 +1,262 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Complete Device Instance Transformations +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify the complete instance_to_conv_traits transformation +// for entire Device class templates. Each test validates that all traits +// are correctly extracted from a specific Device class instantiation. +// +// TEST COVERAGE: +// -------------- +// Complete transformation verification for each XDL Device class template: +// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +// 3. 
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// +// Each test verifies: +// - Spatial dimension extraction +// - Convolution direction +// - Data type detection +// - GEMM padding configuration +// - Tile dimensions (M, N, K per block) +// - Pipeline scheduler and version +// ============================================================================ + +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; + +// ============================================================================ +// Comprehensive Transformation Tests - Per Device Class Template +// ============================================================================ +// These tests verify the complete InstanceTraits → ConvTraits transformation +// for each forward convolution Device class template. 
+// ============================================================================ + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // 
CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // 
NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) +{ + 
using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 
ck::LoopScheduler::Default>; // LoopSched + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +} // anonymous namespace diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index ad31fc52bc..89baf9b51b 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -8,26 +8,27 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_DATA, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdDataConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = 
{.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdDataConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder = ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_data", "fp16", @@ -39,7 +40,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 47908e0e5b..292d852b91 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -8,26 +8,27 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - 
.output = {.config = {.layout = TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdWeightConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdWeightConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder = ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", @@ -39,7 +40,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 083d9d9955..2c35fb5076 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 constexpr auto FwdConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + 
.with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index bf61eb7026..b775505a26 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,18 +28,31 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseXdlGemm +struct XdlParams { - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; size_t m_per_xdl = 0; size_t n_per_xdl = 0; size_t m_xdl_per_wave = 0; size_t n_xdl_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::GridwiseXdlGemmDescriptor); + +// Describe gridwise XDL GEMM parameters. +struct GridwiseFwdXdlGemm +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + size_t ak1 = 0; + size_t bk1 = 0; + XdlParams xdl_params; +}; +static_assert(ckb::GridwiseFwdXdlGemmDescriptor); + +struct GridwiseBwdXdlGemm +{ + size_t k1 = 0; + XdlParams xdl_params; +}; +static_assert(ckb::GridwiseBwdXdlGemmDescriptor); // Describe gridwise WMMA GEMM parameters. 
struct GridwiseWmmaGemm @@ -49,25 +62,36 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); -struct BlockGemm +struct BlockGemmPipeline { PipelineVersion pipeline_version; PipelineScheduler scheduler; }; -static_assert(ckb::BlockGemmDescriptor); +static_assert(ckb::BlockGemmPipelineDescriptor); // Describe Aand B block transfer thread cluster lengths. +template struct BlockTransfer { size_t k0; size_t m_n; size_t k1; + size_t k_batch_size; }; -static_assert(ckb::BlockTransferDescriptor); + +// Specialization for ThreadSliceLength == 3 +template <> +struct BlockTransfer<3> +{ + size_t k0; + size_t m_n; + size_t k1; +}; +static_assert(ckb::BlockTransferDescriptor, 3>); +static_assert(ckb::BlockTransferDescriptor, 4>); // Describe C block transfer thread cluster lengths. struct ThreadCluster @@ -97,31 +121,35 @@ struct Epilogue }; static_assert(EpilogueDescriptor); +template struct AccessOrder { - std::array order; + std::array order; }; -static_assert(AccessOrderDescriptor); +static_assert(ThreadClusterOrderDescriptor>); +static_assert(ThreadClusterOrderDescriptor>); -struct TransferAB +template +struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; - AccessOrder src_access_order; + AccessOrder thread_cluster_arrange_order; + AccessOrder src_access_order; }; -struct TransferC +struct OutputTransfer { ThreadCluster thread_cluster_dims; Epilogue epilogue; }; -struct TransferABC +template +struct Transfer { - TransferAB a; - TransferAB b; - TransferC c; + InputTransfer a; + InputTransfer b; + OutputTransfer c; }; // DL-specific descriptors @@ -142,17 +170,19 @@ struct DlThreadCluster }; static_assert(ckb::DlThreadClusterDescriptor); +template struct DlBlockTransfer { - std::array thread_slice_lengths; - std::array 
thread_cluster_lengths; - std::array thread_cluster_arrange_order; - std::array src_access_order; - std::array src_vector_tensor_lengths; - std::array src_vector_tensor_contiguous_dim_order; - std::array dst_vector_tensor_lengths; + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferDescriptor); +static_assert(ckb::DlBlockTransferDescriptor4D>); +static_assert(ckb::DlBlockTransferDescriptor5D>); struct DlEpilogue { @@ -169,9 +199,14 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct XdlGemm_ +struct FwdXdlGemm_ { - GridwiseXdlGemm gridwise_gemm; + GridwiseFwdXdlGemm gridwise_gemm; +}; + +struct BwdXdlGemm_ +{ + GridwiseBwdXdlGemm gridwise_gemm; }; struct WmmaGemm_ @@ -179,27 +214,48 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; +template struct Transfer_ { - TransferABC transfer; + Transfer transfer; }; -struct ConvSpecialization_ +struct ConvSpecializationFwd_ { - ConvFwdSpecialization fwd_specialization; + ConvSpecialization fwd_specialization; GemmSpecialization gemm_specialization; }; +struct ConvSpecializationBwdWeight_ +{ + ConvSpecialization bwd_weight_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; - size_t num_groups_to_merge; PipelineScheduler loop_scheduler; }; +struct TransposeParams_ +{ + size_t max_transpose_transfer_src_scalar_per_vector{1}; + size_t max_transpose_transfer_dst_scalar_per_vector{1}; +}; + +struct GemmBatchOptions_ +{ + size_t num_conv_groups_to_merge{1}; +}; + struct BlockGemm_ { - BlockGemm block_gemm; + BlockGemmPipeline block_gemm_pipeline; +}; + +struct GridGemm_ +{ + PipelineVersion pipeline_version; }; struct DlThreadConfig_ @@ -212,33 +268,34 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; -struct 
DlBlockTransferAB +template +struct DlTransfer { - DlBlockTransfer block_transfer; -}; - -struct DlBlockTransferC -{ - DlEpilogue epilogue; -}; - -struct DlTransferABC -{ - DlBlockTransferAB a; - DlBlockTransferAB b; - DlBlockTransferC c; + DlBlockTransfer a; + DlBlockTransfer b; + DlEpilogue c; }; +template struct DlTransfer_ { - DlTransferABC transfer; + DlTransfer transfer; }; -// Specialization wrapper for large tensor support -template -struct LargeTensorWrapper +struct TwoStageSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::TWO_STAGE; +}; + +struct MultipleDSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::MULTIPLE_D; +}; + +struct LargeTensorSpecialization_ { - BaseAlgorithm base_algorithm; static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::LARGE_TENSOR; }; @@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + else if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } @@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components... 
{ result.gridwise_gemm = gemm; } + else + { + static_assert(false, "Unrecognized GemmConfig type"); + } return result; } template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; } - constexpr auto with_specializations(ConvFwdSpecialization fwd_spec, - GemmSpecialization gemm_spec) const + constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec, + GemmSpecialization gemm_spec) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; result.fwd_specialization = fwd_spec; result.gemm_specialization = gemm_spec; return result; } - constexpr auto with_prefetch_config(size_t k_prefetch_stages, - size_t groups_to_merge, - PipelineScheduler scheduler) const + constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_weight_specialization = bwd_spec; + return result; + } + + constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const { static_assert(std::is_base_of_v); auto result = *this; result.num_gemm_k_prefetch_stages = k_prefetch_stages; - result.num_groups_to_merge = groups_to_merge; result.loop_scheduler = scheduler; return result; } + constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, + size_t max_dst_scalar_per_vector) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector; + result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector; + return result; + } + + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_conv_groups_to_merge = 
num_groups_to_merge; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { static_assert(std::is_base_of_v); - auto result = *this; - result.block_gemm = bg; + auto result = *this; + result.block_gemm_pipeline = bg; + return result; + } + + constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.pipeline_version = plv; return result; } @@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components... template constexpr auto with_dl_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components... } }; -// Algorithm types +// Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + GridGemm_, + Prefetch_, + GemmBatchOptions_>; + using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; + DlTransfer_<>>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - LargeTensorWrapper; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_, + LargeTensorSpecialization_>; +// CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = + 
ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + MultipleDSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + GridGemm_, + Prefetch_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + MultipleDSpecialization_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index f26b5d7caf..fe94d16a7d 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -81,7 +81,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git 
a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index c7c4e370e2..dbb3a0a8fc 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -184,7 +184,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 5d6bc102e6..bcea406fa7 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -120,36 +120,34 @@ struct DefaultAlgorithm ckb::test::ThreadBlock thread_block{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 16, - .n_per_xdl = 16, - .m_xdl_per_wave = 8, - .n_xdl_per_wave = 8}; + ckb::test::GridwiseFwdXdlGemm gridwise_gemm{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}}; - ckb::test::TransferABC transfer{ + ckb::test::Transfer<> transfer{ .a = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, - .src_access_order = {.order = {0, 1, 2}}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + 
.thread_cluster_arrange_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, }, .b = { - .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 2, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {.order = {0, 1, 2}}, - .src_access_order = {.order = {0, 1, 2}}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {.order = {0, 1, 2}}, + .src_access_order = {.order = {0, 1, 2}}, }, .c = { @@ -161,10 +159,11 @@ struct DefaultAlgorithm }, }; - ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; - ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = ckb::PipelineScheduler::INTRAWAVE}; + ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = + ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6dd2a4eada..ad0a2cadc6 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -795,7 +795,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_inline_diff.cpp b/experimental/builder/test/test_inline_diff.cpp index 8d3a90c95f..6a7a7ac8f7 100644 --- a/experimental/builder/test/test_inline_diff.cpp +++ b/experimental/builder/test/test_inline_diff.cpp @@ -5,8 +5,7 @@ #include "testing_utils.hpp" -namespace ck_tile::builder { -namespace { +using ck_tile::test::inlineDiff; TEST(InlineDiff, simpleColorDiff) { @@ -16,8 +15,8 @@ TEST(InlineDiff, simpleColorDiff) // some easy tests // you can veryfy the ungodly strings are meaningful by running echo -e "" - EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, true), + EXPECT_THAT(inlineDiff(str1, str2, true), "hello"); + EXPECT_THAT(inlineDiff(str1, str3, true), "[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]"); } @@ -28,8 +27,8 @@ TEST(InlineDiff, noColorDiff) std::string str3{"world"}; // some easy tests without color - EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); + EXPECT_THAT(inlineDiff(str1, str2, false), "hello"); + EXPECT_THAT(inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); } TEST(InlineDiff, complexColorDiff) @@ -42,11 +41,8 @@ TEST(InlineDiff, complexColorDiff) "this part has degeahc, this part has, this part added, this part has ana extra letter"}; EXPECT_THAT( - test::inlineDiff(str5, str4, true), + inlineDiff(str5, str4, true), "this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m " "been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], " "this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter"); }; - -} // namespace -} // namespace ck_tile::builder diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index 7a03851ac4..b84d53b6df 
100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include +#include "ck_tile/builder/testing/testing.hpp" #include #include #include @@ -21,6 +22,16 @@ /// dedicated function to override to provide printing support. std::ostream& operator<<(std::ostream& os, hipError_t status); +namespace ck_tile::builder::test { + +template +std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs outputs) +{ + return os << ""; +} + +} // namespace ck_tile::builder::test + namespace ck_tile::test { static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); } @@ -150,4 +161,47 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); +template +struct ReferenceOutputMatcher + : public ::testing::MatcherInterface> +{ + ReferenceOutputMatcher(const builder::test::Args& args, + builder::test::Outputs expected) + : args_(&args), expected_(expected) + { + } + + bool MatchAndExplain(builder::test::Outputs actual, + [[maybe_unused]] ::testing::MatchResultListener* listener) const override + { + const auto report = ck_tile::builder::test::validate(*args_, actual, expected_); + const auto errors = report.get_errors(); + + if(listener->IsInterested() && !errors.empty()) + { + *listener << errors.size() << " tensors failed to validate"; + } + + return errors.empty(); + } + + void DescribeTo(std::ostream* os) const override { *os << ""; } + + void DescribeNegationTo(std::ostream* os) const override + { + *os << "isn't equal to "; + } + + const builder::test::Args* args_; + builder::test::Outputs expected_; +}; + +template +::testing::Matcher> +MatchesReference(const builder::test::Args& args, + builder::test::Outputs expected) +{ + return ::testing::MakeMatcher(new ReferenceOutputMatcher(args, expected)); +} + } // namespace ck_tile::test diff --git 
a/experimental/builder/test/unit_conv_fwd_testing.cpp b/experimental/builder/test/unit_conv_fwd_testing.cpp index 3243935ca5..be95a29a2d 100644 --- a/experimental/builder/test/unit_conv_fwd_testing.cpp +++ b/experimental/builder/test/unit_conv_fwd_testing.cpp @@ -4,6 +4,7 @@ #include "impl/conv_signature_types.hpp" #include "testing_utils.hpp" #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" #include #include #include @@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; using ::testing::ElementsAreArray; +using ::testing::Eq; using ::testing::NotNull; constexpr auto SIGNATURE = @@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs; static_assert(ckt::ValidUniqueInputs); static_assert(ckt::ValidUniqueOutputs); +static_assert(ckt::TensorReflectable); +static_assert(ckt::TensorReflectable); TEST(ConvFwdTesting, MakeDescriptors) { @@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc) EXPECT_THAT(inputs.get().weight, NotNull()); EXPECT_THAT(outputs.get().output, NotNull()); } + +TEST(ConvFwdTesting, Validate) +{ + auto a = alloc_outputs(ARGS); + auto b = alloc_outputs(ARGS); + + // Positive test + { + ckt::Outputs::reflect( + ARGS, + [&]([[maybe_unused]] std::string_view name, + const auto& desc, + void* ckt::Outputs::*ptr) { + ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123}); + ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123}); + }); + + const auto report = ckt::validate(ARGS, a.get(), b.get()); + EXPECT_THAT(report.get_errors().size(), Eq(0)); + } + + // Negative test + { + size_t field_count = 0; + ckt::Outputs::reflect( + ARGS, + [&]([[maybe_unused]] std::string_view name, + const auto& desc, + void* ckt::Outputs::*ptr) { + ++field_count; + ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2}); + ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1}); + }); + + const auto report = ckt::validate(ARGS, a.get(), b.get()); + 
EXPECT_THAT(report.get_errors().size(), Eq(field_count)); + } +} diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index ce31f41933..0df94d977e 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) .weight = {.config = {.layout = GKCX}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using 
TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = GNHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) .weight = {.config = {.layout = GKCYX}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - 
EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) .weight = {.config = {.layout = GKCZYX}}, .output = {.config = {.layout = NGKDHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = NDHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = GNDHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; 
EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 2); using ExpectedType = @@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = GC}, MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 3); using ExpectedType = ck::Tuple aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -387,11 +387,11 @@ 
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; @@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; 
EXPECT_TRUE((std::is_same_v)); @@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index 7ffd446966..b32ce339fa 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -11,40 +11,27 @@ namespace { namespace ckb = ck_tile::builder; using ck_tile::builder::factory::internal::DataTypeToCK; -TEST(ConvTensorType, AssignsTypesForFP16) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} +template +constexpr auto check_same = std::is_same_v::type, T>; -TEST(ConvTensorType, AssignsTypesForBF16) +TEST(ConvTensorType, Exhaustive) { - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} + using enum ckb::DataType; -TEST(ConvTensorType, AssignsTypesForFP32) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForINT32) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForI8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForFP8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); + const auto type = FP32; + // This switch ensures that we get a warning (error with -Werror) if + // a variant is missing. 
+ switch(type) + { + case UNDEFINED_DATA_TYPE: break; + case FP32: EXPECT_TRUE((check_same)); break; + case FP16: EXPECT_TRUE((check_same)); break; + case BF16: EXPECT_TRUE((check_same)); break; + case I32: EXPECT_TRUE((check_same)); break; + case FP8: EXPECT_TRUE((check_same)); break; + case I8: EXPECT_TRUE((check_same)); break; + case U8: EXPECT_TRUE((check_same)); break; + } } } // namespace diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index b35a1ced55..9005742930 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm; + } block_gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - struct GridwiseGemm - { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; - } gridwise_gemm; + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); @@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization) { constexpr struct Algorithm { - ckb::ConvFwdSpecialization fwd_specialization = - ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + ckb::ConvSpecialization fwd_specialization = + ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0; } kAlgorithm; constexpr auto conv_spec = SetFwdConvSpecialization(); diff --git a/experimental/builder/test/unit_debug.cpp b/experimental/builder/test/unit_debug.cpp new file mode 100644 index 0000000000..80ff291782 --- /dev/null +++ b/experimental/builder/test/unit_debug.cpp @@ -0,0 +1,464 @@ +// Copyright (c) Advanced Micro Devices, Inc., 
or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/testing/debug.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using ck_tile::test::StringEqWithDiff; +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Gt; + +TEST(Debug, PrintDescriptor) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{10, 11, 12}, ckt::PackedRightLayout{}); + + std::stringstream ss; + ckt::print_descriptor("test", desc, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Descriptor \"test\":\n" + " data type: I32\n" + " size: 1'320 elements\n" + " space: 1'320 elements (5'280 bytes)\n" + " lengths: [10, 11, 12]\n" + " strides: [132, 12, 1]\n" + " packed: yes\n")); + + // Make sure that the stream locale does not leak. + ss.str(""); + ss << 1000; + EXPECT_THAT(ss.str(), StringEqWithDiff("1000")); +} + +TEST(Debug, LimitedForeach) +{ + { + std::vector values; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 10, + 2, + [&](auto i) { values.push_back(i); }, + [&](auto skip_count) { + ++delim_count; + EXPECT_THAT(skip_count, Eq(10 - 2)); + }); + EXPECT_THAT(values, ElementsAreArray({0, 9})); + EXPECT_THAT(delim_count, Eq(1)); + } + + { + std::vector values; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 100, + 9, + [&](auto i) { values.push_back(i); }, + [&](auto skip_count) { + ++delim_count; + EXPECT_THAT(skip_count, Eq(100 - 9)); + }); + EXPECT_THAT(values, ElementsAreArray({0, 1, 2, 3, 4, 96, 97, 98, 99})); + EXPECT_THAT(delim_count, Eq(1)); + } + + { + size_t call_count = 0; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 50, + 100, + [&](auto i) { + EXPECT_THAT(i, Eq(call_count)); + ++call_count; + }, + [&]([[maybe_unused]] auto skip_count) { ++delim_count; }); + 
EXPECT_THAT(call_count, Eq(50)); + EXPECT_THAT(delim_count, Eq(0)); + } +} + +TEST(Debug, PrintTensor0D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 123; }); + + std::stringstream ss; + ckt::print_tensor("0D", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"0D\": shape = []\n" + " 123\n")); +} + +TEST(Debug, PrintTensor1D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{44}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i % 7; }); + + std::stringstream ss; + ckt::print_tensor("1D", desc, a.get(), {}, ss); + + // Note: output does not involve the size of the matrix separator fields, + // since these are not printed. + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"1D\": shape = [44]\n" + " 0 1 2 3 4 ... 4 5 6 0 1\n")); +} + +TEST(Debug, PrintTensor4D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{100, 110, 120, 130}, + ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i; }); + + std::stringstream ss; + ckt::print_tensor("4D", + desc, + a.get(), + { + // Reduce default limits to have smaller output here. + // That also tests that we can configure these (to some + // extent). + .col_limit = 4, + .row_limit = 4, + .slice_limit = 4, + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"4D\": shape = [100, 110, 120, 130]\n" + "Tensor \"4D\", slice [0, 0, :, :]\n" + " 0 1 ... 128 129\n" + " 130 131 ... 258 259\n" + " ... ... ... ... ...\n" + " 15340 15341 ... 15468 15469\n" + " 15470 15471 ... 15598 15599\n" + "\n" + "Tensor \"4D\", slice [0, 1, :, :]\n" + " 15600 15601 ... 15728 15729\n" + " 15730 15731 ... 15858 15859\n" + " ... ... ... ... ...\n" + " 30940 30941 ... 
31068 31069\n" + " 31070 31071 ... 31198 31199\n" + "\n" + "(skipping 10'996 slices...)\n" + "\n" + "Tensor \"4D\", slice [99, 108, :, :]\n" + " 171568800 171568801 ... 171568928 171568929\n" + " 171568930 171568931 ... 171569058 171569059\n" + " ... ... ... ... ...\n" + " 171584140 171584141 ... 171584268 171584269\n" + " 171584270 171584271 ... 171584398 171584399\n" + "\n" + "Tensor \"4D\", slice [99, 109, :, :]\n" + " 171584400 171584401 ... 171584528 171584529\n" + " 171584530 171584531 ... 171584658 171584659\n" + " ... ... ... ... ...\n" + " 171599740 171599741 ... 171599868 171599869\n" + " 171599870 171599871 ... 171599998 171599999\n")); +} + +TEST(Debug, PrintTensorCustomConfig) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{10, 10, 10}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 101 % 77; }); + + std::stringstream ss; + ckt::print_tensor("CustomConfig", + desc, + a.get(), + { + // Reduce default limits to have smaller output here. + // That also tests that we can configure these. + .col_limit = 4, + .row_limit = 2, + .slice_limit = 6, + // Try with different sizes to make sure that the alignment + // is still correct after changing these. 
+ .row_prefix = ">>>>", + .row_field_sep = "|||||", + .row_skip_val = "-------", + .matrix_row_skip_val = "&&&&&&&&", + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"CustomConfig\": shape = [10, 10, 10]\n" + "Tensor \"CustomConfig\", slice [0, :, :]\n" + ">>>>||||| 0||||| 24|||||-------||||| 38||||| 62\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 4||||| 28|||||-------||||| 42||||| 66\n" + "\n" + "Tensor \"CustomConfig\", slice [1, :, :]\n" + ">>>>||||| 13||||| 37|||||-------||||| 51||||| 75\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 17||||| 41|||||-------||||| 55||||| 2\n" + "\n" + "Tensor \"CustomConfig\", slice [2, :, :]\n" + ">>>>||||| 26||||| 50|||||-------||||| 64||||| 11\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 30||||| 54|||||-------||||| 68||||| 15\n" + "\n" + "(skipping 4 slices...)\n" + "\n" + "Tensor \"CustomConfig\", slice [7, :, :]\n" + ">>>>||||| 14||||| 38|||||-------||||| 52||||| 76\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 18||||| 42|||||-------||||| 56||||| 3\n" + "\n" + "Tensor \"CustomConfig\", slice [8, :, :]\n" + ">>>>||||| 27||||| 51|||||-------||||| 65||||| 12\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 31||||| 55|||||-------||||| 69||||| 16\n" + "\n" + "Tensor \"CustomConfig\", slice [9, :, :]\n" + ">>>>||||| 40||||| 64|||||-------||||| 1||||| 25\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 44||||| 68|||||-------||||| 5||||| 29\n")); +} + +TEST(Debug, PrintTensorUnlimitedMatrix) +{ + // To limit the output of the test, split the "unlimited" test up into one for the + // matrices and one for the slices. 
+ + const ckt::Extent shape = ckt::Extent{12, 12}; + const ckt::TensorPrintConfig default_config; + + // The shape should be larger than the default, otherwise this test doesn't make + // any sense. For this rank-2 extent, index 0 counts rows and index 1 counts columns. + ASSERT_THAT(shape[1], Gt(default_config.col_limit)); + ASSERT_THAT(shape[0], Gt(default_config.row_limit)); + + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i ^ 0xF; }); + + std::stringstream ss; + ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"UnlimitedConfig\": shape = [12, 12]\n" + " 15 14 13 12 11 10 9 8 7 6 5 4\n" + " 3 2 1 0 31 30 29 28 27 26 25 24\n" + " 23 22 21 20 19 18 17 16 47 46 45 44\n" + " 43 42 41 40 39 38 37 36 35 34 33 32\n" + " 63 62 61 60 59 58 57 56 55 54 53 52\n" + " 51 50 49 48 79 78 77 76 75 74 73 72\n" + " 71 70 69 68 67 66 65 64 95 94 93 92\n" + " 91 90 89 88 87 86 85 84 83 82 81 80\n" + " 111 110 109 108 107 106 105 104 103 102 101 100\n" + " 99 98 97 96 127 126 125 124 123 122 121 120\n" + " 119 118 117 116 115 114 113 112 143 142 141 140\n" + " 139 138 137 136 135 134 133 132 131 130 129 128\n")); +} + +TEST(Debug, PrintTensorUnlimitedSlices) +{ + // To limit the output of the test, split the "unlimited" test up into one for the + // matrices and one for the slices. + + const ckt::Extent shape = ckt::Extent{13, 1, 1}; + const ckt::TensorPrintConfig default_config; + + // The shape should be larger than the default, otherwise this test doesn't make + // any sense. 
+ ASSERT_THAT(shape[0], Gt(default_config.slice_limit)); + + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 3; }); + + std::stringstream ss; + ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"UnlimitedConfig\": shape = [13, 1, 1]\n" + "Tensor \"UnlimitedConfig\", slice [0, :, :]\n" + " 0\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [1, :, :]\n" + " 3\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [2, :, :]\n" + " 6\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [3, :, :]\n" + " 9\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [4, :, :]\n" + " 12\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [5, :, :]\n" + " 15\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [6, :, :]\n" + " 18\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [7, :, :]\n" + " 21\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [8, :, :]\n" + " 24\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [9, :, :]\n" + " 27\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [10, :, :]\n" + " 30\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [11, :, :]\n" + " 33\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [12, :, :]\n" + " 36\n")); +} + +TEST(Debug, PrintTensorFP32) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(1.9999, i); }); + + std::stringstream ss; + ckt::print_tensor("FP32", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FP32\": shape = [5, 5]\n" + " 1.000 2.000 4.000 7.999 15.997\n" + " 31.992 63.981 127.955 255.898 511.770\n" + " 1023.488 2046.874 4093.543 8186.677 16372.535\n" + " 32743.432 65483.590 130960.633 261908.172 523790.156\n" + " 1047527.938 2094951.125 4189692.750 
8378966.500 16757095.000\n")); +} + +TEST(Debug, PrintTensorBF16) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(1.2345678f * i); }); + + std::stringstream ss; + ckt::print_tensor("BF16", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"BF16\": shape = [5, 5]\n" + " 0.000 1.234 2.469 3.703 4.938\n" + " 6.188 7.406 8.625 9.875 11.125\n" + " 12.375 13.562 14.812 16.000 17.250\n" + " 18.500 19.750 21.000 22.250 23.500\n" + " 24.750 25.875 27.125 28.375 29.625\n")); +} + +TEST(Debug, PrintTensorFP8) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(i * 0.1f); }); + + std::stringstream ss; + ckt::print_tensor("FP8", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FP8\": shape = [5, 5]\n" + " 0.000 0.102 0.203 0.312 0.406\n" + " 0.500 0.625 0.688 0.812 0.875\n" + " 1.000 1.125 1.250 1.250 1.375\n" + " 1.500 1.625 1.750 1.750 1.875\n" + " 2.000 2.000 2.250 2.250 2.500\n")); +} + +TEST(Debug, PrintTensorSpecialFloats) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { + if(i % 8 == 1) + return 0.f / 0.f; + else if(i % 7 == 1) + return std::sqrt(-1.f); + else if(i % 6 == 1) + return 1.f / 0.f; + else if(i % 5 == 1) + return -1.f / 0.f; + else + return static_cast(i); + }); + + std::stringstream ss; + ckt::print_tensor("specials", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"specials\": shape = [5, 5]\n" + " 0.000 nan 2.000 3.000 4.000\n" + " 5.000 -inf inf -nan nan\n" + " 10.000 -inf 12.000 inf 14.000\n" 
+ " -nan -inf nan 18.000 inf\n" + " 20.000 -inf -nan 23.000 24.000\n")); +} + +TEST(Debug, PrintTensorFloatPrecision) +{ + auto desc = ckt::make_descriptor(ckt::Extent{5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(0.9, i); }); + + std::stringstream ss; + ckt::print_tensor("FloatPrecision", + desc, + a.get(), + { + .float_precision = 10, + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FloatPrecision\": shape = [5]\n" + " 1.0000000000 0.8999999762 0.8100000024 0.7289999723 0.6560999751\n")); +} diff --git a/experimental/builder/test/unit_device_buffer.cpp b/experimental/builder/test/unit_device_buffer.cpp index 75408acc16..548b055238 100644 --- a/experimental/builder/test/unit_device_buffer.cpp +++ b/experimental/builder/test/unit_device_buffer.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: MIT #include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include -#include +#include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; @@ -54,6 +55,11 @@ TEST(DeviceBuffer, AutoFree) // Trying to use a pointer after freeing should return en error in HIP. EXPECT_THAT(hipMemset(ptr, 0xFF, size), HipError(hipErrorInvalidValue)); + + // Reset internal HIP error state. + // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, ThrowsOnOom) @@ -62,13 +68,16 @@ TEST(DeviceBuffer, ThrowsOnOom) auto check = [] { auto buffer = ckt::alloc_buffer(size); }; EXPECT_THAT(check, Throws()); + + // Reset internal HIP error state. 
+ // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, AllocTensorBuffer) { - std::vector lengths = {128, 128, 128}; - std::vector strides = {128 * 128, 128, 1}; - ckt::TensorDescriptor descriptor(lengths, strides); + ckt::TensorDescriptor descriptor({128, 128, 128}, {128 * 128, 128, 1}); auto buffer = ckt::alloc_tensor_buffer(descriptor); @@ -79,3 +88,11 @@ TEST(DeviceBuffer, AllocTensorBuffer) EXPECT_THAT(hipMemset(buffer.get(), 0xFF, descriptor.get_element_space_size_in_bytes()), HipSuccess()); } + +TEST(DeviceBuffer, AlignForward) +{ + EXPECT_THAT(ckt::align_fwd(24, 8), Eq(24)); + EXPECT_THAT(ckt::align_fwd(25, 8), Eq(32)); + EXPECT_THAT(ckt::align_fwd(0xd7c563, 0x1000), Eq(0xd7d000)); + EXPECT_THAT(ckt::align_fwd(19561, 23), Eq(19573)); +} diff --git a/experimental/builder/test/unit_error.cpp b/experimental/builder/test/unit_error.cpp new file mode 100644 index 0000000000..201780cc6a --- /dev/null +++ b/experimental/builder/test/unit_error.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "testing_utils.hpp" +#include +#include + +namespace ckt = ck_tile::builder::test; + +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::Throws; +using ::testing::ThrowsMessage; + +[[noreturn]] void throw_error() { throw ckt::HipError("test error", hipErrorInvalidValue); } + +TEST(HipError, SourceInfo) +{ + EXPECT_THAT(throw_error, + ThrowsMessage(AllOf( + // The error message should include... 
+ // ...the user message + HasSubstr("test error"), + // ...the HIP message + HasSubstr("invalid argument"), + // ...the HIP status code, + HasSubstr("(1)"), + // ...the filename + HasSubstr("experimental/builder/test/unit_error.cpp"), + // ...the function name + HasSubstr("throw_error") + // Note: Don't include the row/column so that we can move + // stuff around in this file. + ))); +} + +TEST(CheckHip, BasicUsage) +{ + EXPECT_THAT([] { ckt::check_hip(hipSuccess); }, Not(Throws())); + EXPECT_THAT([] { ckt::check_hip(hipErrorNotMapped); }, Throws()); + EXPECT_THAT([] { ckt::check_hip(hipErrorOutOfMemory); }, Throws()); + EXPECT_THAT([] { ckt::check_hip("test message", hipErrorAlreadyMapped); }, + ThrowsMessage(HasSubstr("test message"))); +} diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index 07abfe44bd..ce6209795a 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp @@ -1,25 +1,30 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include +#include +#include #include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; +using ck_tile::test::StringEqWithDiff; using ::testing::ElementsAreArray; -using ::testing::Ge; +using ::testing::Eq; +using ::testing::Throws; TEST(TensorDescriptor, Basic) { - constexpr auto dt = ckb::DataType::FP16; - std::vector lengths = {123, 456, 789}; - std::vector strides = {456 * 789, 789, 1}; + constexpr auto dt = ckb::DataType::FP16; + constexpr size_t rank = 3; + ckt::Extent lengths = {123, 456, 789}; + ckt::Extent strides = {456 * 789, 789, 1}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); EXPECT_THAT(descriptor.get_lengths(), ElementsAreArray(lengths)); EXPECT_THAT(descriptor.get_strides(), ElementsAreArray(strides)); @@ -27,21 +32,179 @@ TEST(TensorDescriptor, Basic) TEST(TensorDescriptor, ComputeSize) { - constexpr auto dt = ckb::DataType::FP32; - std::vector lengths = {305, 130, 924}; - std::vector strides = {1000 * 1000, 1, 1000}; + constexpr auto dt = ckb::DataType::FP32; + constexpr size_t rank = 3; + ckt::Extent lengths = {305, 130, 924}; + ckt::Extent strides = {1001 * 1000, 1, 1000}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); - // Compute the location of the last item in memory, then add one - // to get the minimum size. - size_t expected_size = 1; + // Compute the location of the last item in memory, + // then add one to get the minimum size. + size_t expected_size = 1; + size_t expected_numel = 1; for(size_t i = 0; i < lengths.size(); ++i) { expected_size += (lengths[i] - 1) * strides[i]; + expected_numel *= lengths[i]; } - EXPECT_THAT(descriptor.get_element_space_size(), Ge(expected_size)); + EXPECT_THAT(descriptor.get_element_size(), Eq(expected_numel)); + EXPECT_THAT(descriptor.get_element_space_size(), Eq(expected_size)); EXPECT_THAT(descriptor.get_element_space_size_in_bytes(), - Ge(expected_size * ckt::data_type_sizeof(dt))); + Eq(expected_size * ckt::data_type_sizeof(dt))); +} + +TEST(TensorDescriptor, PackedRightLayout) +{ + const ckt::Extent lengths = {5125, 623, 1177, 1534}; + const auto strides = ckt::PackedRightLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({623 * 1177 * 1534, 1177 * 1534, 1534, 1})); +} + +TEST(TensorDescriptor, PackedLeftLayout) +{ + const ckt::Extent lengths = {4, 15, 925, 662, 1462}; + const auto strides = ckt::PackedLeftLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({1, 4, 4 * 15, 4 * 15 * 925, 4 * 15 * 925 * 662})); +} + +TEST(TensorDescriptor, MakeDescriptor) +{ + { + const ckt::Extent lengths = {10, 11, 12, 13, 14}; + + // Note: automatic inference of RANK. + const auto desc = + ckt::make_descriptor(lengths, ckt::PackedRightLayout{}); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), + ElementsAreArray({11 * 12 * 13 * 14, 12 * 13 * 14, 13 * 14, 14, 1})); + } + + { + const ckt::Extent lengths = {4, 3, 2}; + const ckt::Extent strides = {60, 1, 7}; + + // Note: automatic inference of RANK. 
+ const auto desc = ckt::make_descriptor(lengths, strides); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), ElementsAreArray(strides)); + } +} + +TEST(TensorDescriptor, GetSpaceDescriptor) +{ + { + const auto desc = ckt::make_descriptor(ckt::Extent{4, 4, 4}, + ckt::PackedLeftLayout{}); + const auto space = desc.get_space_descriptor(); + + const auto expected = 4 * 4 * 4; + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + } + + { + const ckt::Extent lengths = {6, 3, 4}; + const ckt::Extent strides = {102, 1, 2002}; + const auto desc = ckt::make_descriptor(lengths, strides); + const auto space = desc.get_space_descriptor(); + + // Compute the location of the last item in memory, + // then add one to get the minimum size. 
+ size_t expected_size = 1; + for(size_t i = 0; i < lengths.size(); ++i) + { + expected_size += (lengths[i] - 1) * strides[i]; + } + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected_size})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected_size)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected_size)); + } +} + +TEST(TensorDescriptor, EmptyExtent) +{ + // A rank-0 tensor points to a single element + const auto desc = ckt::make_descriptor(ckt::Extent{}, ckt::Extent{}); + EXPECT_THAT(decltype(desc)::rank, Eq(0)); + EXPECT_THAT(desc.get_lengths().size(), Eq(0)); + EXPECT_THAT(desc.get_strides().size(), Eq(0)); + EXPECT_THAT(desc.get_element_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size_in_bytes(), Eq(2)); + + // We expect a rank-1 tensor with the one dimension being 1. + const auto space = desc.get_space_descriptor(); + + const auto expected = 1; + + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size_in_bytes(), Eq(2)); +} + +TEST(TensorDescriptor, ExtentFromVector) +{ + EXPECT_THAT(ckt::Extent<4>::from_vector(std::vector{1, 2, 3, 4}), + ElementsAreArray({1, 2, 3, 4})); + + EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector{1, 2}); }, + Throws()); +} + +TEST(TensorDescriptor, IsPacked) +{ + constexpr auto dt = ckb::DataType::I32; // Irrelevant for this test + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{}) + .is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{}) + .is_packed()); + EXPECT_TRUE(ckt::make_descriptor
(ckt::Extent{}, ckt::Extent{}).is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725}) + .is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed()); +} + +TEST(TensorDescriptor, PrintExtent) +{ + { + const ckt::Extent extent{6233, 55, 1235, 52, 203}; + std::stringstream ss; + ss << extent; + EXPECT_THAT(ss.str(), StringEqWithDiff("[6233, 55, 1235, 52, 203]")); + } + + { + const ckt::Extent extent{}; + std::stringstream ss; + ss << extent; + EXPECT_THAT(ss.str(), StringEqWithDiff("[]")); + } } diff --git a/experimental/builder/test/unit_tensor_foreach.cpp b/experimental/builder/test/unit_tensor_foreach.cpp new file mode 100644 index 0000000000..f689d3c82f --- /dev/null +++ b/experimental/builder/test/unit_tensor_foreach.cpp @@ -0,0 +1,227 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using ::testing::Each; +using ::testing::Eq; + +TEST(TensorForeach, NdIter) +{ + { + ckt::NdIter iter(ckt::Extent{523, 345, 123, 601}); + + EXPECT_THAT(iter.numel(), Eq(13'338'296'505ULL)); + EXPECT_THAT(iter(0), Eq(ckt::Extent{0, 0, 0, 0})); + EXPECT_THAT(iter(1), Eq(ckt::Extent{0, 0, 0, 1})); + EXPECT_THAT(iter(601), Eq(ckt::Extent{0, 0, 1, 0})); + EXPECT_THAT(iter(601 * 123), Eq(ckt::Extent{0, 1, 0, 0})); + EXPECT_THAT(iter(601 * 123 * 10), Eq(ckt::Extent{0, 10, 0, 0})); + EXPECT_THAT(iter(((34 * 345 + 63) * 123 + 70) * 601 + 5), Eq(ckt::Extent{34, 63, 70, 5})); + } + + { + ckt::NdIter iter(ckt::Extent{}); + + EXPECT_THAT(iter.numel(), Eq(1)); + EXPECT_THAT(iter(0), Eq(ckt::Extent{})); + } +} + +TEST(TensorForeach, CalculateOffset) +{ + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{523, 266, 263}, 
ckt::Extent{1, 545, 10532}), + Eq(2915409)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{}, ckt::Extent{}), Eq(0)); + // Note: >4 GB overflow test + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{8, 2, 5, 7, 0, 4, 1, 3, 6, 9}, + ckt::Extent{1'000, + 1'000'000, + 10'000'000, + 1'000'000'000, + 1, + 10'000, + 100, + 10, + 100'000'000, + 100'000}), + Eq(size_t{7'652'948'130})); +} + +TEST(TensorForeach, VisitsCorrectCount) +{ + // tensor_foreach should visit every index exactly once. + // This test checks that the count is at least correct. + + const ckt::Extent shape = {10, 20, 30}; + + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + ckt::tensor_foreach(shape, [count = d_count.get()]([[maybe_unused]] const auto& index) { + atomicAdd(reinterpret_cast(count), 1); + }); + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + const auto expected = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + EXPECT_THAT(actual, Eq(expected)); +} + +TEST(TensorForeach, VisitsEveryIndex) +{ + const ckt::Extent shape = {5, 6, 7, 8, 9, 10, 11}; + const auto total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + // We know this is correct due to testing in unit_tensor_descriptor.cpp + const auto stride = ckt::PackedRightLayout{}(shape); + + auto d_output = ckt::alloc_buffer(sizeof(uint32_t) * total); + ckt::check_hip(hipMemset(d_output.get(), 0, sizeof(uint32_t) * total)); + + ckt::tensor_foreach(shape, [output = d_output.get(), stride](const auto& index) { + // We know this is correct due to the CalculateOffset test. + auto offset = ckt::calculate_offset(index, stride); + + // Use atomic add so that we can check that every index is visited exactly once. 
+ atomicAdd(&reinterpret_cast(output)[offset], 1); + }); + + std::vector actual(total); + ckt::check_hip( + hipMemcpy(actual.data(), d_output.get(), sizeof(uint32_t) * total, hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Each(Eq(1))); +} + +TEST(TensorForeach, FillTensorBuffer) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{31, 54, 13}, ckt::PackedRightLayout{}); + + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), [](size_t i) { return static_cast(i); }); + + std::vector h_buffer(desc.get_element_space_size()); + ckt::check_hip(hipMemcpy( + h_buffer.data(), buffer.get(), h_buffer.size() * sizeof(uint32_t), hipMemcpyDeviceToHost)); + + for(size_t i = 0; i < h_buffer.size(); ++i) + { + EXPECT_THAT(h_buffer[i], Eq(static_cast(i))); + } +} + +TEST(TensorForeach, FillTensor) +{ + // FillTensor with non-packed indices should not write out-of-bounds. + const ckt::Extent shape = {4, 23, 35}; + const ckt::Extent pad = {12, 53, 100}; + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + const auto strides = desc.get_strides(); + + auto size = desc.get_element_space_size(); + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), []([[maybe_unused]] size_t i) { return 123; }); + + ckt::fill_tensor(desc, buffer.get(), []([[maybe_unused]] const auto& index) { return 1; }); + + auto d_error = ckt::alloc_buffer(sizeof(uint32_t) * size); + ckt::check_hip(hipMemset(d_error.get(), 0, sizeof(uint32_t))); + + ckt::tensor_foreach( + // Iterate over the entire padding so that we can check out-of-bounds elements + pad, + [shape, pad, strides, size, error = d_error.get(), tensor = buffer.get()]( + const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + const auto value = reinterpret_cast(tensor)[offset]; + + // Note: The space of the descriptor will not actually be (12, 53, 100) but + // more like (4, 53, 100), as the outer stride is irrelevant. 
So we have to + // perform an extra bounds check here. + if(offset < size) + { + // Check if the coordinate is within the shape bounds. + bool in_bounds = true; + for(size_t i = 0; i < shape.size(); ++i) + { + if(index[i] >= shape[i]) + { + in_bounds = false; + } + } + + // In-bounds elements are 1, out-of-bounds is 123. + if(in_bounds && value != 1) + { + atomicAdd(reinterpret_cast(error), 1); + } + else if(!in_bounds && value != 123) + { + atomicAdd(reinterpret_cast(error), 1); + } + } + }); + + uint32_t error_count = 0; + ckt::check_hip(hipMemcpy(&error_count, d_error.get(), sizeof(uint32_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(error_count, Eq(0)); +} + +TEST(TensorForeach, ClearTensorZeros) +{ + const ckt::Extent shape = {5, 4, 5, 4, 5, 4, 5, 6}; + const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6}; + + const auto desc = + ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + + auto buffer = ckt::alloc_tensor_buffer(desc); + ckt::clear_tensor_buffer(desc, buffer.get()); + + // Check that all values are zeroed. + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + { + const auto size = desc.get_element_space_size(); + const auto strides = desc.get_strides(); + auto* count = d_count.get(); + const auto* tensor = reinterpret_cast(buffer.get()); + // Note: iterate over the entire pad, so that we can check out-of-bounds elements. + ckt::tensor_foreach(pad, + [count, tensor, strides, size]([[maybe_unused]] const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + + // Note: The space of the descriptor will not actually be (6, 6, + // ...) but more like (5, 6, ...), as the outer stride is + // irrelevant. So we have to perform an extra bounds check here. 
+ if(offset < size && tensor[offset] != 0) + { + atomicAdd(reinterpret_cast(count), 1); + } + }); + } + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Eq(0)); +} diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp new file mode 100644 index 0000000000..a83d034ac2 --- /dev/null +++ b/experimental/builder/test/unit_validation.cpp @@ -0,0 +1,300 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/testing/testing.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using testing::ElementsAreArray; +using testing::Eq; +using testing::StrEq; + +using ck_tile::test::MatchesReference; +using ck_tile::test::StringEqWithDiff; + +// Googletest cannot have both type AND value parameterized tests. +// For now just act lazy and use value template parameters. 
+template +struct Param +{ + constexpr static auto data_type = DT; + constexpr static auto shape = SHAPE; + constexpr static auto strides = STRIDES; + + constexpr static auto rank = shape.size(); + + static ckt::TensorDescriptor get_descriptor() + { + return ckt::make_descriptor(shape, strides); + } +}; + +template +struct ValidationReportTests : public ::testing::Test +{ +}; + +using Types = ::testing::Types< + Param, + Param, + Param, + Param>; + +TYPED_TEST_SUITE(ValidationReportTests, Types); + +TYPED_TEST(ValidationReportTests, SingleCorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + // Generate a sort-of-random looking sequence + auto generator = [strides = desc.get_strides()](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, strides); + return static_cast((flat_index + 1) * 10'000'019 % 768'351); + }; + + ckt::fill_tensor(desc, a.get(), generator); + ckt::fill_tensor(desc, b.get(), generator); + + ckt::ValidationReport report; + report.check("correct", desc, b.get(), a.get()); + + EXPECT_THAT(report.get_errors().size(), Eq(0)); +} + +TYPED_TEST(ValidationReportTests, SingleIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + const auto packed_strides = ckt::PackedRightLayout{}(desc.get_lengths()); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::fill_tensor(desc, a.get(), []([[maybe_unused]] const auto& i) { return 123; }); + ckt::fill_tensor(desc, b.get(), [packed_strides](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, packed_strides); + return flat_index == 0 ? 0 : flat_index == 12345 ? 456 : flat_index == 999999 ? 
1 : 123; + }); + + ckt::ValidationReport report; + report.check("incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + + const auto flat_size = desc.get_element_size(); + const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1; + + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); +} + +TYPED_TEST(ValidationReportTests, ZeroIsIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::ValidationReport report; + report.check("zero_is_incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(0)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); + EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size())); +} + +TEST(ValidationReportTests, MultipleSomeIncorrect) +{ + ckt::ValidationReport report; + + { + auto desc = ckt::make_descriptor({'R', 'O', 'C', 'm'}, + ckt::PackedLeftLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(i % 100); }); + ckt::fill_tensor_buffer( + desc, b.get(), [](size_t i) { return ck::type_convert(i % 101); }); + + report.check("incorrect 1", desc, b.get(), a.get()); + } + + { + auto desc = + ckt::make_descriptor({'H', 'I', 'P'}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return 
"ROCm"[i % 4]; }); + ckt::fill_tensor_buffer(desc, b.get(), [](size_t i) { + switch(i % 4) + { + case 0: return 'R'; + case 1: return 'O'; + case 2: return 'C'; + case 3: return 'm'; + default: return 'x'; + } + }); + + report.check("correct", desc, b.get(), a.get()); + } + + { + auto desc = + ckt::make_descriptor({'G', 'P', 'U'}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; }); + ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; }); + + report.check("incorrect 2", desc, b.get(), a.get()); + } + + const auto errors = report.get_errors(); + + ASSERT_THAT(errors.size(), Eq(2)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1")); + EXPECT_THAT(errors[0].wrong_elements, Eq(46840334)); + EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2")); + EXPECT_THAT(errors[1].wrong_elements, Eq(482800)); +} + +// MatchesReference operates on the types defined in testing.hpp, so just +// quickly define a bunch of dummy values for that. + +struct DummySignature +{ +}; + +constexpr DummySignature DUMMY_SIGNATURE = {}; + +namespace ck_tile::builder::test { + +template <> +struct Args +{ + auto make_a_descriptor() const + { + return make_descriptor(Extent{5, 5, 5, 5}, PackedRightLayout{}); + } + + auto make_b_descriptor() const + { + return make_descriptor(Extent{100000}, PackedLeftLayout{}); + } +}; + +template <> +struct Outputs +{ + void* a; + void* b; +}; + +// Explicitly implement validate for this type to test that that works. 
+template <> +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) +{ + ValidationReport report; + report.check("a", args.make_a_descriptor(), actual.a, expected.a); + report.check("b", args.make_b_descriptor(), actual.b, expected.b); + return report; +} + +} // namespace ck_tile::builder::test + +TEST(MatchesReference, Correct) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + EXPECT_THAT(actual, MatchesReference(args, expected)); +} + +TEST(MatchesReference, Incorrect) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 2); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + testing::StringMatchResultListener listener; + 
EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); + + EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); +} diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index ad5a5f4f6f..e48f1dd6ba 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -15,52 +15,63 @@ using namespace test; constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; +constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{ + .k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; -constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; +constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlTransferABC DlFwdTransfer{.a = - { - .block_transfer = DlBlockTransferAB, - }, - .b = - { - .block_transfer = DlBlockTransferAB, - }, - .c = { - .epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}, - }}; +constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2, + .b = DlBlockTransfer_8x1x1x2, + .c = 
{.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}}; -constexpr TransferABC FwdTransfer_4x64x1{ +constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{ + .thread_slice_lengths = {1, 8, 1, 1, 1}, + .thread_cluster_lengths = {1, 2, 1, 128, 1}, + .thread_cluster_arrange_order = {0, 2, 3, 1, 4}, + .src_access_order = {0, 2, 3, 1, 4}, + .src_vector_tensor_lengths = {1, 1, 1, 1, 1}, + .src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4}, + .dst_vector_tensor_lengths = {1, 1, 1, 1, 1}}; + +constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, + .b = DlBlockTransfer_1x8x1x1x1, + .c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 1}}; + +constexpr Transfer<> Transfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 4, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -72,28 +83,28 @@ constexpr 
TransferABC FwdTransfer_4x64x1{ }, }; -constexpr TransferABC FwdTransfer_4x64x1_fp8{ +constexpr Transfer<4> BwdTransfer_4x64x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, }, .c = { @@ -105,28 +116,94 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{ }, }; -constexpr TransferABC FwdTransfer_4x16x1{ +constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ .a = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = 
false}, + .thread_cluster_arrange_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .thread_cluster_arrange_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, + }, +}; + +constexpr Transfer<> Transfer_4x64x1_fp8{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; + +constexpr Transfer<> Transfer_4x16x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, 
+ .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -139,28 +216,28 @@ constexpr TransferABC FwdTransfer_4x16x1{ }, }; -constexpr TransferABC FwdTransfer_4x32x1{ +constexpr Transfer<> Transfer_4x32x1{ .a = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .b = { - .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, - .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .block_transfer_access_order = {1, 0, 2}, - .src_access_order = {1, 0, 2}, + .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .thread_cluster_arrange_order = {1, 0, 2}, + .src_access_order = {1, 0, 2}, }, .c = { @@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{ }, }; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; +constexpr GridwiseBwdXdlGemm 
BwdGemmParams_Xdl_4x4_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; -constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = PipelineVersion::V1}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; -constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; -constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256, - .tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr GridwiseWmmaGemm 
GemmParams_Wmma_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 8}}; -constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock 
ThreadBlock_64_32x32x32{.block_size = 64, + .tile_size = {.m = 32, .n = 32, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = { + .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = { + .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = { + .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = { + .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = { + .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp index 377234dd19..41a1250854 100644 --- a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -constexpr TileTransfer FwdTileTransfer_1x1x1{ +constexpr TileTransfer TileTransfer_1x1x1{ .a_scalar_per_vector = 1, .b_scalar_per_vector = 1, .c_scalar_per_vector = 1, }; -constexpr 
TileTransfer FwdTileTransfer_4x4x4{ +constexpr TileTransfer TileTransfer_4x4x4{ .a_scalar_per_vector = 4, .b_scalar_per_vector = 4, .c_scalar_per_vector = 4, }; -constexpr TileTransfer FwdTileTransfer_8x8x8{ +constexpr TileTransfer TileTransfer_8x8x8{ .a_scalar_per_vector = 8, .b_scalar_per_vector = 8, .c_scalar_per_vector = 8, }; -constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { .warps = {.m = 2, .n = 2, .k = 1}, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index e4db149a98..178029e338 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -54,7 +54,7 @@ inline std::string to_string(PipelineScheduler 
t) } template <> -inline std::string to_string(ConvFwdSpecialization t) +inline std::string to_string(ConvSpecialization t) { std::ostringstream oss; oss << t; @@ -86,11 +86,20 @@ inline std::string to_string(ThreadBlock t) } template <> -inline std::string to_string(GridwiseXdlGemm t) +inline std::string to_string(GridwiseBwdXdlGemm t) { std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << "," - << t.m_xdl_per_wave << "," << t.n_xdl_per_wave; + oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseFwdXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl + << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -104,17 +113,29 @@ inline std::string to_string(GridwiseWmmaGemm t) } template <> -inline std::string to_string(BlockGemm t) +inline std::string to_string(BlockGemmPipeline t) { std::ostringstream oss; oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); return oss.str(); } -template <> -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + if constexpr(ThreadClusterRank == 4) + { + return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); + } + else if constexpr(ThreadClusterRank == 3) + { + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + } + else + { + static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, + "Unsupported ThreadClusterRank"); + } } template <> @@ -134,17 +155,17 @@ inline std::string to_string(LdsTransfer t) return oss.str(); } -template <> -inline std::string to_string(AccessOrder t) +template +inline std::string to_string(AccessOrder t) { 
return array_to_seq(t.order); } -template <> -inline std::string to_string(TransferAB t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; - oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," + oss << to_string(t.block_transfer) << "," << to_string(t.thread_cluster_arrange_order) << "," << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << "," << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector << "," << (t.lds_transfer.lds_padding ? "true" : "false"); @@ -152,7 +173,7 @@ inline std::string to_string(TransferAB t) } template <> -inline std::string to_string(TransferC t) +inline std::string to_string(OutputTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," @@ -160,8 +181,8 @@ inline std::string to_string(TransferC t) return oss.str(); } -template <> -inline std::string to_string(TransferABC t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -185,7 +206,19 @@ inline std::string to_string(DlThreadCluster t) } template <> -inline std::string to_string(DlBlockTransfer t) +inline std::string to_string>(DlBlockTransfer<4> t) +{ + std::ostringstream oss; + oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) + << "," << array_to_seq(t.thread_cluster_arrange_order) << "," + << array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths) + << "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << "," + << array_to_seq(t.dst_vector_tensor_lengths); + return oss.str(); +} + +template <> +inline std::string to_string>(DlBlockTransfer<5> t) { std::ostringstream oss; oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) @@ -206,19 +239,24 @@ inline 
std::string to_string(DlEpilogue t) } template <> -inline std::string to_string(DlBlockTransferAB t) +inline std::string to_string(TransposeParams_ t) { - return to_string(t.block_transfer); + std::ostringstream oss; + oss << t.max_transpose_transfer_src_scalar_per_vector << "," + << t.max_transpose_transfer_dst_scalar_per_vector; + return oss.str(); } template <> -inline std::string to_string(DlBlockTransferC t) +inline std::string to_string>(DlTransfer<4> t) { - return to_string(t.epilogue); + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + return oss.str(); } template <> -inline std::string to_string(DlTransferABC t) +inline std::string to_string>(DlTransfer<5> t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -234,7 +272,13 @@ inline std::string to_string(ThreadBlock_ t) } template <> -inline std::string to_string(XdlGemm_ t) +inline std::string to_string(FwdXdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(BwdXdlGemm_ t) { return to_string(t.gridwise_gemm); } @@ -245,33 +289,40 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template <> -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } template <> -inline std::string to_string(ConvSpecialization_ t) +inline std::string to_string(ConvSpecializationFwd_ t) { std::ostringstream oss; oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization); return oss.str(); } +template <> +inline std::string to_string(ConvSpecializationBwdWeight_ t) +{ + std::ostringstream oss; + oss << to_string(t.bwd_weight_specialization); + return oss.str(); +} + template <> inline std::string to_string(Prefetch_ t) { std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << "," - << to_string(t.loop_scheduler); 
+ oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); return oss.str(); } template <> inline std::string to_string(BlockGemm_ t) { - return to_string(t.block_gemm); + return to_string(t.block_gemm_pipeline); } template <> @@ -287,7 +338,13 @@ inline std::string to_string(DlThreadCluster_ t) } template <> -inline std::string to_string(DlTransfer_ t) +inline std::string to_string>(DlTransfer_<4> t) +{ + return to_string(t.transfer); +} + +template <> +inline std::string to_string>(DlTransfer_<5> t) { return to_string(t.transfer); } @@ -299,8 +356,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -309,8 +366,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -320,7 +377,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -332,7 +389,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," - << to_string(static_cast(t)); + << to_string(static_cast>(t)); return oss.str(); } @@ -340,7 +397,102 @@ template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t) { - return to_string(t.base_algorithm); + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << 
to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << 
to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); } } // namespace ck_tile::builder::test diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp index 4da38bf399..60bc3110d4 100644 --- a/include/ck/library/utility/device_tensor_generator.hpp +++ b/include/ck/library/utility/device_tensor_generator.hpp @@ -7,7 +7,6 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/device_tensor_generator.hpp" #include "ck/utility/data_type.hpp" -#include // use xorshift for now since it is simple. Should be suitable enough, but feel free to switch in // the future @@ -107,6 +106,7 @@ template __global__ void fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size) { + static constexpr float PI = 3.141592653f; // initial values ran_state_u32 s = ran_init(); float norm[2]; @@ -115,12 +115,11 @@ fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_e { if(j % (2 / ck::packed_size_v) == 0) { - float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - norm[0] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean; - norm[1] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean; + float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float scale = sigma * ck::math::sqrt(-2.0f * ck::math::log(u1)); + norm[0] = scale * ck::math::cos(2.0f * PI * u2) + mean; + norm[1] = scale * ck::math::sin(2.0f * PI * u2) + mean; } if constexpr(ck::is_same_v) diff --git 
a/include/ck/library/utility/gpu_verification.hpp b/include/ck/library/utility/gpu_verification.hpp new file mode 100644 index 0000000000..e4a444ecb9 --- /dev/null +++ b/include/ck/library/utility/gpu_verification.hpp @@ -0,0 +1,425 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/type.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/library/utility/check_err.hpp" + +namespace ck { +namespace profiler { + +// Result struct for GPU verification with detailed error reporting +// Provides backward compatibility via operator bool() +struct GpuVerifyResult +{ + unsigned long long error_count; // Number of elements that exceeded tolerance + float max_error; // Maximum error value observed + std::size_t total; // Total number of elements compared + bool all_zero; // True if device result is all zeros (likely kernel issue) + + // Implicit conversion to bool for backward compatibility + // Allows: if (gpu_verify(...)) { ... } + operator bool() const { return error_count == 0; } + + // Calculate error percentage + float error_percentage() const + { + if(total == 0) + return 0.0f; + return static_cast(error_count) / static_cast(total) * 100.0f; + } + + // Print error summary to stderr (matches check_err format) + void print_error_summary() const + { + if(error_count > 0) + { + if(all_zero) + { + std::cerr << "WARNING: Device result is all zeros - kernel may not have executed " + "properly!" 
+ << std::endl; + } + std::cerr << "max err: " << max_error; + std::cerr << ", number of errors: " << error_count; + std::cerr << ", " << std::setprecision(2) << std::fixed << error_percentage() + << "% wrong values" << std::endl; + } + } +}; + +// Compute relative tolerance for GPU verification +// Matches the logic of ck::utils::get_relative_threshold but handles all types +template +inline float compute_relative_tolerance(const int number_of_accumulations = 1) +{ + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + using I8 = int8_t; + using I16 = int16_t; + using I32 = int32_t; + + // For integer types, tolerance is 0 + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + return 0.0f; + } + // For types supported by get_relative_threshold, use it + else if constexpr((std::is_same_v || + std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v)) + { + return static_cast( + ck::utils::get_relative_threshold( + number_of_accumulations)); + } + // For unsupported types (FP8, BF8, etc.), use default tolerances based on output type + else + { + if constexpr(std::is_same_v) + { + return 1e-3f; + } + else if constexpr(std::is_same_v) + { + return 1e-1f; + } + else + { + // For FP8/BF8 and other types, use conservative tolerance + return 1e-1f; + } + } +} + +// Device-side result structure for kernel output +// Packed into a single struct to minimize device memory allocations +struct GpuVerifyDeviceResult +{ + unsigned long long error_count; // Number of errors found + float max_error; // Maximum error value + int all_zero; // 1 = device result is all zeros, 0 = has non-zero values +}; + +// GPU verification kernel - compares device result against reference using relative and absolute +// tolerance. Tracks all errors (no early exit) to provide detailed error reporting. 
+// +// Uses LDS (shared memory) for block-level reduction to minimize atomic contention. +// This reduces atomic operations from O(errors) to O(blocks), providing massive speedup +// when there are many errors. +// +// Assumption: Block size is 256 +template +__global__ void gpu_verify_kernel(const T* __restrict__ device_result, + const T* __restrict__ reference_result, + float rtol, + float atol, + long long size, + GpuVerifyDeviceResult* result) +{ + constexpr int block_size = 256; + + // Shared memory for block-level reduction + __shared__ unsigned long long shared_error_count[block_size]; + __shared__ float shared_max_error[block_size]; + __shared__ int shared_has_error[block_size]; + __shared__ int shared_has_nonzero[block_size]; + + // Thread-local accumulators (in registers) + unsigned long long local_error_count = 0; + float local_max_error = 0.0f; + int local_has_error = 0; + int local_has_nonzero = 0; + + // Grid-stride loop to handle any tensor size + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + long long stride = blockDim.x * gridDim.x; + + for(long long i = idx; i < size; i += stride) + { + // Convert to float for comparison + float dev_val = type_convert(device_result[i]); + float ref_val = type_convert(reference_result[i]); + + // Check if device value is non-zero + if(dev_val != 0.0f) + { + local_has_nonzero = 1; + } + + // Compute absolute difference + float abs_diff = fabsf(dev_val - ref_val); + + // Check tolerance (matches CPU check_err logic: err > atol + rtol * abs(ref)) + if(abs_diff > atol + rtol * fabsf(ref_val)) + { + local_has_error = 1; + local_error_count++; + local_max_error = fmaxf(local_max_error, abs_diff); + } + } + + // Store thread-local results to shared memory + shared_error_count[threadIdx.x] = local_error_count; + shared_max_error[threadIdx.x] = local_max_error; + shared_has_error[threadIdx.x] = local_has_error; + shared_has_nonzero[threadIdx.x] = local_has_nonzero; + __syncthreads(); + + // Block-level 
reduction: 256 -> 128 -> 64 -> 32 + for(unsigned int s = block_size / 2; s >= 32; s >>= 1) + { + if(threadIdx.x < s) + { + shared_error_count[threadIdx.x] += shared_error_count[threadIdx.x + s]; + shared_max_error[threadIdx.x] = + fmaxf(shared_max_error[threadIdx.x], shared_max_error[threadIdx.x + s]); + shared_has_error[threadIdx.x] |= shared_has_error[threadIdx.x + s]; + shared_has_nonzero[threadIdx.x] |= shared_has_nonzero[threadIdx.x + s]; + } + __syncthreads(); + } + + // Final reduction of remaining 32 elements in thread 0 + if(threadIdx.x == 0) + { + for(int i = 1; i < 32; ++i) + { + shared_error_count[0] += shared_error_count[i]; + shared_max_error[0] = fmaxf(shared_max_error[0], shared_max_error[i]); + shared_has_error[0] |= shared_has_error[i]; + shared_has_nonzero[0] |= shared_has_nonzero[i]; + } + + // Single atomic update per block (reduces contention from O(errors) to O(blocks)) + if(shared_has_error[0]) + { + atomicAdd(&result->error_count, shared_error_count[0]); + atomicMax(&result->max_error, shared_max_error[0]); + } + // Update all_zero flag: if no nonzero values found, mark as all zero + if(!shared_has_nonzero[0]) + { + atomicMin(&result->all_zero, 1); + } + else + { + atomicMin(&result->all_zero, 0); + } + } +} + +// Host-side wrapper for GPU verification with explicit tolerances +// Returns GpuVerifyResult with detailed error information +template +GpuVerifyResult gpu_verify(const void* device_result, + const void* reference_result, + float rtol, + float atol, + std::size_t size, + hipStream_t stream = nullptr) +{ + // Allocate result buffer on device + GpuVerifyDeviceResult* result_dev; + hip_check_error(hipMalloc(&result_dev, sizeof(GpuVerifyDeviceResult))); + + // Initialize result struct + GpuVerifyDeviceResult result_host; + result_host.error_count = 0; // No errors yet + result_host.max_error = 0.0f; // No error observed + result_host.all_zero = 1; // Start assuming all zeros (will be cleared if nonzero found) + hip_check_error( + 
hipMemcpy(result_dev, &result_host, sizeof(GpuVerifyDeviceResult), hipMemcpyHostToDevice)); + + // Launch kernel with grid-stride loop + // Use 65535 as max grid size (hardware limit for grid dimension in x) + // Grid-stride loop handles any tensor size regardless of grid dimensions + constexpr int block_size = 256; + int grid_size = std::min(65535, (size + block_size - 1) / block_size); + + gpu_verify_kernel + <<>>(static_cast(device_result), + static_cast(reference_result), + rtol, + atol, + static_cast(size), + result_dev); + + hip_check_error(hipGetLastError()); + + // Synchronize the stream to ensure kernel completion before reading results + hip_check_error(hipStreamSynchronize(stream)); + + // Get result + hip_check_error( + hipMemcpy(&result_host, result_dev, sizeof(GpuVerifyDeviceResult), hipMemcpyDeviceToHost)); + + // Free device memory + hip_check_error(hipFree(result_dev)); + + // Build and return result struct + GpuVerifyResult result; + result.error_count = result_host.error_count; + result.max_error = result_host.max_error; + result.total = size; + result.all_zero = (result_host.all_zero == 1); + + return result; +} + +// Forward declaration of gpu_reduce_max +template +float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream = nullptr); + +// Host-side wrapper for GPU verification with automatic tolerance computation +// Computes max value on GPU, then computes tolerances and verifies +// Returns GpuVerifyResult with detailed error information +template +GpuVerifyResult gpu_verify(const void* device_result, + const void* reference_result, + int number_of_accumulations, + std::size_t size, + hipStream_t stream = nullptr) +{ + // Compute max absolute value on GPU (only 4 bytes transferred!) 
+ double max_abs_value = + static_cast(gpu_reduce_max(reference_result, size, stream)); + + // Compute tolerances based on data types and accumulation count + float rtol = compute_relative_tolerance( + number_of_accumulations); + + float atol = 0.0f; + // Only compute absolute tolerance for supported types + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + + if constexpr((std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v)) + { + atol = static_cast( + ck::utils::get_absolute_threshold( + max_abs_value, number_of_accumulations)); + } + + // Call the explicit tolerance version + return gpu_verify(device_result, reference_result, rtol, atol, size, stream); +} + +// GPU reduction kernel for computing max(abs(data)) +// This is an internal kernel called only by gpu_reduce_max() wrapper. +// +// Assumption: Block size is 256 +template +__global__ void +gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restrict__ max_val) +{ + constexpr int block_size = 256; + __shared__ float shared_max[block_size]; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + long long stride = blockDim.x * gridDim.x; + + float local_max = 0.0f; + + for(long long i = idx; i < size; i += stride) + { + float val = fabsf(type_convert(data[i])); + local_max = fmaxf(local_max, val); + } + + shared_max[threadIdx.x] = local_max; + __syncthreads(); + + // Block-level reduction: 256 -> 128 -> 64 -> 32 + for(unsigned int s = block_size / 2; s >= 32; s >>= 1) + { + if(threadIdx.x < s) + { + shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]); + } + __syncthreads(); + } + + // Final reduction of remaining 32 elements in thread 0 + if(threadIdx.x == 0) + { + for(int i = 1; i < 32; ++i) + { + shared_max[0] = fmaxf(shared_max[0], shared_max[i]); + } + + // Single atomic update per block + 
atomicMax(max_val, shared_max[0]); + } +} + +// Host-side wrapper for GPU max reduction +// Computes max(abs(data)) and returns as float +// Only transfers 4 bytes (the final max value) instead of entire tensor +template +float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream) +{ + if(size == 0) + { + return 0.0f; + } + + // Allocate device memory for result + float* max_dev; + hip_check_error(hipMalloc(&max_dev, sizeof(float))); + + // Initialize to zero + float init_val = 0.0f; + hip_check_error(hipMemcpy(max_dev, &init_val, sizeof(float), hipMemcpyHostToDevice)); + + // Launch reduction kernel + // Use 1024 blocks max for reduction to balance occupancy vs. grid-stride iterations + // For tensors larger than 1024 * 256 = 262144 elements, the grid-stride loop handles the remainder + constexpr int block_size = 256; + int grid_size = std::min(1024, (size + block_size - 1) / block_size); + + gpu_reduce_max_kernel<<>>( + static_cast(device_buffer), static_cast(size), max_dev); + + hip_check_error(hipGetLastError()); + + // Always wait on the launch stream before reading the result back: the blocking hipMemcpy + // below runs on the null stream, which does not implicitly synchronize with streams created + // as non-blocking. This also matches what gpu_verify() does after its kernel launch. + hip_check_error(hipStreamSynchronize(stream)); + + // Copy result to host (only 4 bytes!) 
+ float max_host; + hip_check_error(hipMemcpy(&max_host, max_dev, sizeof(float), hipMemcpyDeviceToHost)); + + // Free device memory + hip_check_error(hipFree(max_dev)); + + return max_host; +} + +} // namespace profiler +} // namespace ck diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 896c048781..ea1c15b1aa 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include struct StreamConfig { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 3b12e7feb0..4f884b1df3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + static bool __host__ __device__ BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { @@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + __host__ __device__ static bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp index ade8035877..2154f35815 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/stream_utility.hpp" + #include "device_grouped_gemm.hpp" namespace ck { @@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : 
public DeviceGroupedGemm +struct TileLoopKernelConfig +{ + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + static int GetCuBlocks() + { + int BLOCK_WAVES = BlockSize / get_warp_size(); + return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + } + + template + static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = GetKernelOccupancy(kernel); + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. 
grid size: " + << ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks()); + } + + template + static int GetKernelOccupancy(const KernelFunction& kernel) + { + int occupancy = 0; + ck::hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + return occupancy; + } + + static int GetComputeUnitCount() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + ck::hip_check_error(hipGetDevice(&dev)); + ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + } +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..47ef2e339d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,956 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/scheduler_enum.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_contraction_multiple_d_wmma_cshuffle_v3(typename DeviceOp::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + static constexpr index_t NumDTensor = GridwiseOp::NumDTensor; + + const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = + amd_wave_read_first_lane(karg.compute_ptr_offset_of_batch_.GetDsPtrOffset(g_idx)); + + typename GridwiseOp::AsGridPointer p_as_grid_batch{karg.p_a_grid_ + a_batch_offset}; + typename GridwiseOp::BsGridPointer p_bs_grid_batch{karg.p_b_grid_ + b_batch_offset}; + typename GridwiseOp::DsGridPointer p_ds_grid_batch; + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + 
ds_batch_offset[i]; }); + + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + const auto a_grid_desc_ak0_m_ak1 = + GridwiseOp::MakeAGridDescriptor_AK0_M_AK1(karg.a_grid_desc_m_k_); + const auto b_grid_desc_bk0_n_bk1 = + GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_); + + auto epilogue_args = EpilogueType{}; + GridwiseOp::template Run( + p_as_grid_batch, + p_bs_grid_batch, + p_ds_grid_batch, + karg.p_e_grid_ + e_batch_offset, + p_shared, + make_tuple(a_grid_desc_ak0_m_ak1), + make_tuple(b_grid_desc_bk0_n_bk1), + karg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + karg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + karg.block_2_etile_map_, + karg.a_element_op_, + karg.b_element_op_, + karg.cde_element_op_, + epilogue_args); +#else + ignore = karg; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. 
In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] 
+ static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] 
+ const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... 
+ const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... 
+ constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + // GridwiseGemm + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using DsLayout = decltype(generate_tuple( + [](auto) { return ck::tensor_layout::gemm::RowMajor{}; }, Number{})); + using ELayout = ck::tensor_layout::gemm::RowMajor; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + 
ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA + false // PermuteB + >; + + // block-to-e-tile map + using Block2ETileMap = GridwiseGemm::Block2CTileMap; + + // problem grid descriptors + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}, 0, 0))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}, 0, 0))>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return 
static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + KBatch(1), + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + compute_ptr_offset_of_batch_{a_gs_ms_ks_strides[NumDimG - 1], + b_gs_ns_ks_strides[NumDimG - 1], + ds_grid_desc_g_m_n_, + e_grid_desc_g_m_n_} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, + "Invalid number of dimensions"); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // Extract 2D GEMM dimensions + G = e_grid_desc_g_m_n_.GetLength(I0); + M = e_grid_desc_g_m_n_.GetLength(I1); + N = e_grid_desc_g_m_n_.GetLength(I2); + K = a_grid_desc_m_k_.GetLength(I1); + AK0 = GridwiseGemm::CalculateAK0Padded(K); + + // Number of tiles along M and N; the N count must come from the N helper (NPerBlock tiling), + // not the M helper, otherwise it is wrong whenever MPerBlock != NPerBlock. + index_t MBlock = GridwiseGemm::CalculateMBlock(M); + index_t NBlock = GridwiseGemm::CalculateNBlock(N); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_, MBlock, NBlock); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_, MBlock, NBlock); + + block_2_etile_map_ = GridwiseGemm::DefaultBlock2CTileMap(M, N); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; 
+ EDataType* p_e_grid_; + + index_t G, M, N, K; + index_t KBatch; // Always 1, but included for compatibility with GridwiseGemm::CheckValidity + index_t AK0; // Also included for compatibility + + // tensor descriptors for problem definition + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + // AK0_M_AK1/BK0_N_BK1 are generated in the kernel to match the transfer method used + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error( + "wrong! DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 has invalid " + "setting"); + } + + const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.M, arg.N); + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr auto tail_num = tail_number.value; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + const auto kernel = + kernel_contraction_multiple_d_wmma_cshuffle_v3; + + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size, arg.G, 1), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + throw std::runtime_error( + "Invalid HasMainKBlockLoop and TailNum combination for pipeline V1!\n"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + throw std::runtime_error( + "Invalid HasMainKBlockLoop and TailNum combination for pipeline V3!\n"); + } + } + else + { + throw std::runtime_error("Invalid pipeline version! 
Only V1 and V3 supported\n"); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "GPU Arch not supported" << std::endl; + } + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "Wrong dimension for A or B vector loads, should be 1 or 2!"); + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const 
std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 2a1a210398..126d107725 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -314,6 +314,10 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + 
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; int max_occupancy = 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index 11e2add132..a18f108e47 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t c_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index ee1ddc494d..b88f071a96 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = 
GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); // The normal approach to batching would be to increase the grid size by just stretching out // the grid Z dimension (which is the outermost dimension), but this depends on lower level // functions not directly using the Z dimension for other calculations. As it turns out, k @@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; }); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_as_grid_shift, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..e8e3b69cb5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -0,0 +1,685 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_bias_add_reduce_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, + typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid, + const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops, + const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops, + const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + auto epilogue_args = EpilogueType( + p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore 
= karg; + ignore = p_reduces_grid; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = d0_element_op; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 + : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using CDEShuffleBlockTransferScalarPerVectors = Sequence; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + false, // IsBPreShuffled + false, // ForceThreadTileTransfer + true>; // IsFusedKernel + + using ReduceTrait = ReduceTrait_; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + 
index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + StrideA_{StrideA}, + StrideB_{StrideB}, + StrideC_{StrideC}, + StrideC1_{StrideC1}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t StrideA_; + index_t StrideB_; + index_t StrideC_; + index_t StrideC1_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + 
arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1); + + float ave_time = 0; + + index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_); + + const auto Run = [&](const auto& kernel) { + // Note: cache flushing not supported + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.p_reduces_grid_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.d0_element_op_); + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! 
Invalid pipeline setting"); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(TailNum == TailNumber::Full) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v1 setting"); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(TailNum == TailNumber::Even) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + else if(TailNum == TailNumber::Odd) + { + const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3< + GridwiseGemm, + ReduceTrait, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + throw std::runtime_error("wrong! Invalid pipeline v3 setting"); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Device implementation supports only gfx11 and gfx12! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "FP8 and BF8 not supported on gfx11! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + typename GridwiseGemm::Argument gemm_arg{ + std::array{arg.p_a_grid_}, + std::array{arg.p_b_grid_}, + std::array{arg.p_bias_grid_, arg.p_d0_grid_}, + static_cast(arg.p_e_grid_), + arg.MRaw_, + arg.NRaw_, + arg.KRaw_, + std::array{arg.StrideA_}, // StrideAs + std::array{arg.StrideB_}, // StrideBs + std::array{0, arg.StrideC1_}, // StrideDs + arg.StrideC_, // StrideE + 1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = 
ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + 
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Wmma_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index d35f22ba4a..f0216c3f71 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -273,7 +273,10 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, + false, + true>; // Welford 2nd part kernel template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 0240fcb619..317c4073df 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -49,8 +49,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = - EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M); + auto epilogue_args = EpilogueType(p_reduces_grid, + reduce_in_element_ops, + reduce_out_element_ops, + karg.M, + tensor_operation::element_wise::PassThrough{}); GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -184,10 +187,14 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera ComputeTypeA, ComputeTypeB, PermuteA, - PermuteB>; + PermuteB, + false, + false, + true>; using ReduceTrait = ReduceTrait_ +#include +#include +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_BPreshuffle + : public DeviceGemmV2BPreshuffle +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + CLayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + 
CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + Tuple<>, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* 
p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(std::array{p_a}, + std::array{p_b}, + std::array{}, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_BPreshuffle_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << 
" BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x" << NPerWmma << ", " + << "WaveMap: " + << MRepeat << "x" << NRepeat << ", " + << "VmemReadVec: " + << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << "BlkGemmPipelinePrefetchStages: " + << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", " + << "Kpack: " + << GridwiseGemm::KPack; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 7bc3be1a95..bbf62d5fbe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -63,11 +63,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index d33e807828..b324845c3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1698,6 +1698,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } + else + { + valid = false; + } } else { @@ -1716,6 +1720,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 valid = false; } } + else + { + valid = false; + } } if(!valid) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3b8be8bf8..bc072a7019 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -50,7 +50,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -62,10 +62,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(CGlobalMemoryDataOperation != 
InMemoryDataOperationEnum::AtomicAdd) { #endif @@ -861,30 +858,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } @@ -900,30 +899,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + 
remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } @@ -1028,6 +1029,17 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { return false; } + + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." 
<< std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1807dc1d9f..d3bf2a364a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -52,19 +52,20 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight_multiple_d( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const 
ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) @@ -568,7 +569,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle int max_occupancy = 0; hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_batched_gemm_xdlops_bwd_weight< + kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, ADataType, BDataType, @@ -841,7 +842,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; - const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, ADataType, BDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 7f1669cf13..f9b2ff0596 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -63,28 +63,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; + auto epilogue_args = typename 
GridwiseGemm::EpilogueCShuffle{}; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run(p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - compute_ptr_offset_of_batch, - num_k_per_block, - karg, - epilogue_args); + GridwiseGemm::template Run(p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + compute_ptr_offset_of_batch, + num_k_per_block, + karg, + epilogue_args); +#if defined(__gfx11__) + } +#endif #else ignore = karg; ignore = a_grid_desc_ak0_m_ak1; @@ -460,6 +466,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; @@ -1179,6 +1189,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 return false; } + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported splitK on gfx11." 
<< std::endl; + } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + // Check this here, it allows to use other instances from factory even // if workspace is not allocated if(!arg.p_workspace_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 3e8a0fd3fb..211496b3ff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -24,6 +24,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,13 +61,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t 
split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -77,23 +84,29 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; #endif // end of if (defined(__gfx9__)) } @@ -118,14 +131,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = 
__builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -139,24 +158,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; + ignore = split_k_offset_hack; + ignore = split_k_stride_a; + ignore = split_k_stride_b; #endif // end of if (defined(__gfx9__)) } @@ -693,7 +718,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle k_batch_ = split_k; } - const auto descs = + // Create initial descriptors with hack=false to check compactness + const auto descs_initial = conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( Conv_N_, @@ -709,11 +735,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = 
descs[I1]; - ce_grid_desc_m_n_ = descs[I2]; + k_batch_, + false, // hack=false for initial check + true); // use_full_batch_kindex ce_elementwise_grid_desc_m_n_ = conv_to_gemm_transformer_v1 @@ -733,6 +757,67 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_)[I2]; + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + KPerBlock); + + // Create final descriptors with correct hack flag + const auto descs = + conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + split_k_offset_hack_, // Use determined hack flag + true); // use_full_batch_kindex + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + // Step 5: Calculate stride using CalculateOffset on FINAL descriptors + if(split_k_offset_hack_) + { + const index_t k0_per_batch = a_grid_desc_k0_m_k1_.GetLength(I0) / k_batch_; + const auto idx_start = make_multi_index(0, 0, 0); + const auto idx_next = make_multi_index(k0_per_batch, 0, 0); + split_k_stride_a_ = a_grid_desc_k0_m_k1_.CalculateOffset(idx_next) - + a_grid_desc_k0_m_k1_.CalculateOffset(idx_start); + } + else + { + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + } + + if(split_k_offset_hack_) + { + const index_t k0_per_batch = b_grid_desc_k0_n_k1_.GetLength(I0) / k_batch_; + const auto idx_start = make_multi_index(0, 0, 0); + const auto idx_next = make_multi_index(k0_per_batch, 0, 0); + split_k_stride_b_ = b_grid_desc_k0_n_k1_.CalculateOffset(idx_next) - + b_grid_desc_k0_n_k1_.CalculateOffset(idx_start); + } + else 
+ { + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + } + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); @@ -869,6 +954,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -971,7 +1059,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } else { @@ -987,7 +1078,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } }; @@ -1920,14 +2014,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } - constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && - arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) - { - return false; - } - return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 213b72050e..3f8093afe1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -64,11 +64,7 @@ 
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) + if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd) { #endif constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< @@ -419,6 +415,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { ActiveWorkgroupsPerCU() { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } constexpr int dynamic_smem_size = 0; constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; @@ -1089,18 +1089,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 return false; } - if constexpr(std::is_same_v || - std::is_same_v) + if(gemm_arg.KBatch > 1 && ck::is_gfx11_supported()) { - if(gemm_arg.KBatch > 1 && ck::is_gfx11_supported()) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Unsupported splitK on gfx11." << std::endl; - } - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; + std::cout << "Unsupported splitK on gfx11." 
<< std::endl; } + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; } if constexpr(std::is_same_v || std::is_same_v || diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 42ad21dafe..976b6f1ef8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -21,6 +21,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -33,6 +34,74 @@ namespace ck { namespace tensor_operation { namespace device { +// Dispatch helper function for split-K hack - handles 2-way dispatch based on runtime flag +template +__device__ void DispatchBatchedGemmSplitKHack(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack, + index_t k_batch) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + 
c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } + else + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + k_batch); + } +} + template (p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + DispatchBatchedGemmSplitKHack( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + split_k_offset_hack, + k_batch); } #else ignore = p_a_grid; @@ -104,6 +193,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = batch_count; ignore = block_2_ctile_map; ignore = compute_ptr_offset_of_batch; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; + ignore = k_batch; compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); @@ -459,7 +552,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, ComputePtrOffsetOfStridedBatch<>, - false>, // Both true/false give the same occupancy. 
+ false>, // HasMainKBlockLoop - both true/false give the same occupancy BlockSize, dynamic_smem_size)); return std::max(1, max_occupancy); @@ -576,6 +669,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle k_batch_ = split_k; } + // Create descriptors first (with hack flags temporarily set to false) + // so we can check if element space sizes are divisible by k_batch + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + false); // split_k_offset_b_hack (temporary) + + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock * K1); + + // Now create descriptors with the correct hack flag const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -592,12 +716,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_hack_); a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + // Calculate stride using CalculateOffset method for accurate stride + // This works correctly for any descriptor transform pipeline + split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_b_ /= k_batch_; + block_2_ctile_map_ = GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); @@ -732,6 +867,9 @@ 
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -878,7 +1016,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_batch_, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_, + arg.k_batch_); }; if(has_main_k0_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 9df78f55e5..2121be00d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" @@ -58,13 +59,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if 
constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -74,20 +81,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; @@ -96,6 +107,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; + #endif // end of if (defined(__gfx9__) } @@ -119,14 +134,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch 
compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -140,21 +161,24 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); } #else ignore = karg; @@ -163,6 +187,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + 
ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_hack; #endif // end of if (defined(__gfx9__) } @@ -490,8 +517,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -560,6 +587,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 k_batch_ = split_k; } + // Create descriptors first (with hack flags temporarily set to false) + // so we can check if element space sizes match product of dimensions + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + false, // split_k_offset_b_hack (temporary) + true); // use_full_batch_kindex=true for V1-compatible descriptors + + split_k_offset_hack_ = + SplitKHackEligibility::Check( + descs_initial[I0], + descs_initial[I1], + k_batch_, + Conv_N_, + output_spatial_lengths_, + K0PerBlock); + + // Now create descriptors with the correct hack flag const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -576,11 +635,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_hack_, + true); // use_full_batch_kindex=true for V1-compatible descriptors - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; 
+ c_grid_desc_m_n_ = descs[I2]; + + // Calculate stride using CalculateOffset method for accurate stride + // This works correctly for any descriptor transform pipeline + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_hack_) + split_k_stride_b_ /= k_batch_; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; @@ -591,8 +662,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -604,8 +675,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -631,6 +702,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -640,17 +714,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << 
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl; - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -659,10 +731,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 template float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -680,7 +752,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const bool 
has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { if(arg.k_batch_ > 1) @@ -716,11 +788,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } else { @@ -732,11 +807,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_hack_); } }; @@ -749,7 +827,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, @@ -781,7 +859,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be One to Seven else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { @@ -1090,7 +1168,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be Odd or Even else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - 
if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1159,7 +1237,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1232,7 +1310,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, @@ -1289,10 +1367,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } #endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(is_same_v || is_same_v) { @@ -1423,9 +1501,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && - arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB)) + const bool a_small_enough = arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() / + (arg.split_k_offset_hack_ ? arg.k_batch_ : 1) * + sizeof(ADataType) <= + TwoGB; + const bool b_small_enough = arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() / + (arg.split_k_offset_hack_ ? 
arg.k_batch_ : 1) * + sizeof(BDataType) <= + TwoGB; + const bool c_small_enough = + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB; + if(!(a_small_enough && b_small_enough && c_small_enough)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index df128c10b9..ee05c7c6a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -48,8 +48,8 @@ namespace { * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). */ template ))) { #endif - __shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>()]; + using EpilogueType = + typename std::conditional::type; - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; - GridwiseGemm::template Run::value, Number<0>, @@ -289,9 +303,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 NPerBlock / ClusterLengthNPerBlock>{}; template - static auto - MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) - + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -307,21 +319,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - const auto M = in_gemmm_gemmk_desc.GetLength(I0); - const auto K = in_gemmm_gemmk_desc.GetLength(I1); - - const auto AK0 = K / AK1; - - return 
transform_tensor_descriptor(in_gemmm_gemmk_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return in_gemmm_gemmk_desc; } template - static auto - MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< @@ -337,16 +339,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - const auto N = wei_gemmn_gemmk_desc.GetLength(I0); - const auto K = wei_gemmn_gemmk_desc.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(wei_gemmn_gemmk_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return wei_gemmn_gemmk_desc; } template @@ -364,15 +357,21 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + // Force MN padding on the output tensor. This allows to use Gemm default or only K padding + // and remove some instructions in the hot loop (same approach used for gemm universal). 
if constexpr(CTranspose) { - constexpr auto matrix_padder_trans = - MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; - return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding_trans = + MatrixPadder{ + NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_MN_padding_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } else { - return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + constexpr auto matrix_padder_MN_padding = + MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + return matrix_padder_MN_padding.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); } } @@ -452,10 +451,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 BlkGemmPipelineVer, AComputeDataType, BComputeDataType, - false, // PermuteA - false, // PermuteB - false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + UseThreadTileTransfer>; // ForceThreadTileTransfer // TODO: Previously available template param DoElementwiseBeforeCShuffle! 
@@ -529,7 +528,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 false, // PermuteB false, // PermuteA false, // IsBPreShuffled - true>; // ForceThreadTileTransfer + true>; // ForceThreadTileTransfer (always force it because of limitations in the transfer) using GridwiseGemmCTranspose = std::conditional_t; @@ -626,10 +625,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 I1>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; - using BGridDesc_BK0_N_BK1 = remove_cvref_t( - dummy_conv_to_gemm_transformer))>; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // Argument struct Argument : public BaseArgument @@ -695,10 +694,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 ds_grid_desc_m_n_{}, e_grid_desc_m_n_{ DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - a_grid_desc_ak0_m_ak1_{ - MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, - b_grid_desc_bk0_n_bk1_{ - MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -798,8 +795,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } { - const index_t GemmM = a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmM = a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = b_grid_desc_n_k_.GetLength(I0); const auto MBlock = CTranspose ? GridwiseGemmCTranspose::CalculateMBlock(GemmN) : GridwiseGemmCTranspose::CalculateMBlock(GemmM); const auto NBlock = CTranspose ? 
GridwiseGemmCTranspose::CalculateNBlock(GemmM) @@ -883,7 +880,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 is_same_v) { size_as_buffers[i] = - (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + (a_grid_desc_m_k_.GetElementSpaceSize() + (num_group_ - NumGroupsToMerge) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -891,13 +888,13 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { if(CTranspose && a_g_n_c_wis_lengths_[I1] > 1) { - size_as_buffers[i] = (a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() + + size_as_buffers[i] = (a_grid_desc_m_k_.GetElementSpaceSize() + (eff_num_group - 1) * (a_g_n_c_wis_strides_[0])) * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } else { - size_as_buffers[i] = a_grid_desc_ak0_m_ak1_.GetElementSpaceSize() * + size_as_buffers[i] = a_grid_desc_m_k_.GetElementSpaceSize() * eff_num_group * sizeof(ADataType_single) / GridwiseGemm::APackedSize; } @@ -914,7 +911,7 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 static_for<0, NumBTensor, 1>{}([&](auto i) { using BDataType_single = remove_cvref_t>; - size_bs_buffers[i] = b_grid_desc_bk0_n_bk1_.GetElementSpaceSize() * eff_num_group * + size_bs_buffers[i] = b_grid_desc_n_k_.GetElementSpaceSize() * eff_num_group * sizeof(BDataType_single) / GridwiseGemm::BPackedSize; }); @@ -961,8 +958,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 void Print() const { - std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; - std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -998,8 +995,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 DsGridDesc_M_N 
ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -1048,10 +1045,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); const index_t num_workgroups_per_Conv_N = arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; @@ -1193,8 +1189,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1210,8 +1206,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.b_grid_desc_bk0_n_bk1_, - arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1291,8 +1287,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + 
arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1308,8 +1304,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, @@ -1327,8 +1323,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemmCTranspose, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_N_K, + DeviceOp::AGridDesc_M_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1342,8 +1338,8 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_wmma_cshuffle_v3< GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_M_K, + DeviceOp::BGridDesc_N_K, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, ComputePtrOffset, @@ -1985,10 +1981,9 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } // check Gridwise GEMM - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); - const index_t GemmK = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_m_k_.GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_.GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_.GetLength(I1); if constexpr(CTranspose) { diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 7cb0ae20c3..cc343f6f69 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -2108,7 +2108,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::unique_ptr describe() const override { - static_assert(ck_tile::reflect::conv::HasConvTraits, + static_assert(ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. " "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 4f410d0cce..c9fb8ca3f6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -1282,7 +1282,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor std::unique_ptr describe() const override { static_assert( - ck_tile::reflect::conv::HasConvTraits, + ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. 
" "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp new file mode 100644 index 0000000000..5ae9eaf8ac --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -0,0 +1,693 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. +/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. 
+/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. +/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[LDS_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + // Create A&B grid pointer containing their single 
tensors + typename GridwiseGemm::AsGridPointer p_as_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_a_grid)); + typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_b_grid)); + + // Make a DsGridPointer instance containing all D tensors + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + std::array stride_Ds; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i]; + }); + + index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + std::array{gemm_desc_ptr[group_id].StrideA}, + std::array{gemm_desc_ptr[group_id].StrideB}, + stride_Ds, + gemm_desc_ptr[group_id].StrideE, + 1); + + auto epilogue_args = EpilogueType{}; + constexpr TailNumber TailNum = TailNumber::Full; + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, 
+ epilogue_args); + } + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx11__) || defined(__gfx12__)) +} + +template + +struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by GridwiseOp. + false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class. 
+ + using KernelConfig = TileLoopKernelConfig; + using KernelArguments = GroupedGemmKernelArgument; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + const std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) 
+ /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto kernel = GetKernelFunction(); + + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + // run multiple kernels + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" 
<< " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_wmma; + return kernel; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + bool supported = true; + for(index_t i = 0; i < arg.group_count_; ++i) + { + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + + typename GridwiseGemm::Argument gridwise_arg( + std::array{nullptr}, // p_a_grid, + std::array{nullptr}, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + std::array{arg.gemm_descs_[i].stride_A_}, + std::array{arg.gemm_descs_[i].stride_B_}, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + 1, // KBatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + false); + + bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg); + supported = supported && group_arg_valid; + + if(!group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gridwise_arg.Print(); + } + } + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, 
"v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpyAsync(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 4492e6474f..a9e81f5563 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -26,6 +27,18 @@ namespace ck { namespace tensor_operation { namespace device { +// Dummy kernel to use as a fallback in the kernel selection logic +// Is not used in practice, but only used in case of misconfigured parameters +template +__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*, + const index_t, + const AElementwiseOperation, + const BElementwiseOperation, + const CDEElementwiseOperation) +{ +} /// /// @brief Entry point kernel for device-wide Grouped GEMM operation. /// @@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; + using KernelConfig = TileLoopKernelConfig; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; @@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop index_t tile_count_; }; - struct KernelConfig - { - // The oversubscription factor for the number of blocks that can simultaneously reside on - // GPU. 
- static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; - // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); - static constexpr int CU_SIMDS = 4; - // Assume we want to have at most 2 waves per SIMD - // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - static int GetCuBlocks() - { - int BLOCK_WAVES = BlockSize / get_warp_size(); - return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - } - }; - // Invoker struct Invoker : public BaseInvoker { @@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const void* dev_gemm_args, const StreamConfig& stream_config) const { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + const auto kernel = GetKernelFunction(); return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); } - template - int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, - const StreamConfig& stream_config) const - { - // Calculate max number of workgroups that can simultaneously reside on the CU. - int occ_num_blocks = 0; - size_t dyn_shared_mem_per_blk = 0; - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); - - int cu_count = getAvailableComputeUnitCount(stream_config); - - if(stream_config.log_level_ > 0) - { - std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks - << ", available CUs count: " << cu_count << ", occup. 
grid size: " - << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count - << std::endl; - } - - return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()); - } - template float LaunchKernel(const KernelFunction& kernel, const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config) const { - int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); if(stream_config.log_level_ > 0) { @@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return IsSupportedArgument(*dynamic_cast(p_arg)); } - static int GetKernelOccupancy() + template + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return kernel; + } + + static auto GetKernelFunction() { - int occupancy = 0; if(get_warp_size() == 64) { if constexpr(NXdlPerWave64 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } else { - if constexpr(NXdlPerWave32 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } - return occupancy; + + // This is here to handle the case where MXdlPerWave/NxdPerWave is too small + // This is caught by IsSupportedArgument(), but as GetKernelFunction is sometimes called + // before we need a fallback kernel to return here. 
+ return kernel_dummy; + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); } static auto MakeArgument(std::vector& p_As, @@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return Argument{p_As, p_Bs, @@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) override { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return std::make_unique(p_As, p_Bs, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 2f0c047167..39024d39e4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -7,6 +7,7 @@ #include #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -40,8 +41,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t group_count) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template 
GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -88,13 +93,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run(static_cast(p_shared), @@ -125,7 +130,6 @@ template + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK; // PermuteB not supported by DeviceBatchedGemm base class. + false, // PermuteA not supported by GridwiseOp + false>; // PermuteB not supported by DeviceGroupedGemm base class using CGridDesc_M_N = remove_cvref_t( @@ -242,7 +244,6 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; template struct GemmTransKernelArgBase { @@ -274,23 +275,38 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK& p_As, std::vector& p_Bs, + std::vector>& p_Ds, std::vector& p_Es, - std::vector& gemm_descs) - : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) { // TODO: use occupancy api to calculate appropriate batch size. 
} Argument(std::vector& p_As, std::vector& p_Bs, + std::vector>& p_Ds, std::vector& p_Es, std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, index_t kbatch) : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} { @@ -299,9 +315,11 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(p_As.size()) && group_count_ == ck::type_convert(p_Bs.size()) && + ((NumDTensor == 0 && p_Ds.size() == 0) || + group_count_ == ck::type_convert(p_Ds.size())) && group_count_ == ck::type_convert(p_Es.size()))) { - throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); } gemm_kernel_args_.reserve(group_count_); @@ -320,9 +338,22 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(stride_d_vec.size()))) + { + throw std::runtime_error("wrong! stride D mismatch"); + } + + // Copy D stride vector to fixed-size array + std::array stride_ds; + if constexpr(NumDTensor > 0) + { + std::copy(stride_d_vec.begin(), stride_d_vec.end(), stride_ds); + } const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t n_padded = GridwiseGemm::CalculateNPadded(N); @@ -346,19 +377,19 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK{p_As[i]}, std::array{p_Bs[i]}, - std::array{}, // p_ds_grid_ + p_Ds[i], type_convert(p_Es[i]), M, N, K, std::array{stride_a}, std::array{stride_b}, - std::array{}, // StrideDs_ + stride_ds, stride_c, K_BATCH, - PassThrough{}, - PassThrough{}, - PassThrough{}, + a_element_op, + b_element_op, + c_element_op, false); gemm_kernel_args_.emplace_back( @@ -632,6 +663,23 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK) + { + if(arg.K_BATCH > 1) + { + // Using SplitK and a C element op would require a two stage kernel where the second + // stage applies the op on the accumulated results + 
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "C element operators are not supported when using SplitK. Set " + "K_BATCH to 1 or remove the operator." + << std::endl; + } + return false; + } + } + if constexpr(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { @@ -681,14 +729,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK& p_As, std::vector& p_Bs, - std::vector>&, + std::vector>& p_Ds, std::vector& p_Es, std::vector gemm_descs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation) + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) { - return Argument{p_As, p_Bs, p_Es, gemm_descs}; + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -697,14 +746,15 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK MakeArgumentPointer(std::vector& p_As, std::vector& p_Bs, - std::vector>&, + std::vector>& p_Ds, std::vector& p_Es, std::vector& gemm_descs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation) override + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override { - return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); } // polymorphic @@ -730,7 +780,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK + typename LDSTypeB = ComputeTypeB, + bool NonTemporalLoadB = false> struct DeviceMoeGemmBlockScale : public DeviceGemmMultipleD_BlockScale_BPreshuffle; + LDSTypeB, + NonTemporalLoadB>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp 
b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp new file mode 100644 index 0000000000..6fe4257dbb --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_offset_utils.hpp @@ -0,0 +1,222 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Check if a tensor descriptor has compact layout +// Compact means: GetElementSpaceSize() == product of all dimension lengths +// Non-compact descriptors have complex transform pipelines that may not support split-k hack +template +bool IsDescriptorCompact(const Descriptor& desc) +{ + // Calculate product of all dimensions + long_index_t dims_product = 1; + constexpr index_t num_dims = Descriptor::GetNumOfDimension(); + + // Use template recursion to multiply all dimension lengths + static_for<0, num_dims, 1>{}( + [&](auto i) { dims_product *= static_cast(desc.GetLength(i)); }); + + return desc.GetElementSpaceSize() == dims_product; +} + +// Determine split-k hack eligibility for descriptor pair +// This checks all the conditions required for safely using the split-k offset hack +template +struct SplitKHackEligibility +{ + template + static bool + Check(const ADescriptor& a_desc, + const BDescriptor& b_desc, + index_t k_batch, + index_t Conv_N, + const std::array& output_spatial_lengths, + index_t k_block_size) // K0PerBlock*K1 for v1, K0PerBlock for v3, KPerBlock for two-stage + { + // Only enable hack if k_batch > 1 + if(k_batch <= 1) + { + return false; + } + + // Calculate output spatial product + const index_t output_spatial_acum = std::accumulate(output_spatial_lengths.begin(), + output_spatial_lengths.end(), + index_t{1}, + 
std::multiplies()); + + // Check various divisibility and layout requirements + const bool is_k_not_paded = (Conv_N * output_spatial_acum) % (k_block_size * k_batch) == 0; + + const bool can_divide_n_spatial_by_k_batch = (Conv_N * output_spatial_acum) % k_batch == 0; + + const bool can_divide_n_by_k_batch = Conv_N % k_batch == 0; + + const bool is_correct_layout = + is_NSpatialGC_GKSpatial_NSpatialGK(); + + const bool is_a_stride_divisible = a_desc.GetElementSpaceSize() % k_batch == 0; + + const bool is_b_stride_divisible = b_desc.GetElementSpaceSize() % k_batch == 0; + + // Check descriptor compactness + const bool is_a_compact = IsDescriptorCompact(a_desc); + const bool is_b_compact = IsDescriptorCompact(b_desc); + + // Require BOTH A and B to be eligible for the hack to avoid KBatch dimension mismatch + // The gridwise kernel's CheckValidity requires A.KBatch == B.KBatch, so we must + // apply the hack uniformly to both tensors to maintain kernel applicability + const bool eligible = can_divide_n_spatial_by_k_batch && can_divide_n_by_k_batch && + is_k_not_paded && is_correct_layout && is_a_stride_divisible && + is_b_stride_divisible && is_a_compact && is_b_compact; + + return eligible; + } +}; + +// Helper function to dispatch split-K hack for standard kernel (single LDS) +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_hack) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + 
c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +// Helper function to dispatch split-K hack for 2lds kernel +// Reduces code duplication in device layer implementations +template +__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_hack) +{ + if(split_k_offset_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 2c17b82608..dc102ef805 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include 
"ck/tensor_operation/gpu/element/quantization_operation.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -236,8 +237,9 @@ struct MultiplyAdd const half_t& d0, const half_t& d1) const { - const half_t y = type_convert(c) * d0 + d1; - e = y; + const half_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(bhalf_t& e, @@ -245,8 +247,9 @@ struct MultiplyAdd const bhalf_t& d0, const bhalf_t& d1) const { - const bhalf_t y = type_convert(c) * d0 + d1; - e = y; + const bhalf_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(float& e, diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index 942d4351b3..d1e7f35607 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -10,6 +10,7 @@ namespace ck { template const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // Thread transfer LDS to Vmem - auto cde_shuffle_block_copy_lds_and_global = - Base::template GetLDSToVmemEpilogueDescriptor( - c_ds_desc_refs, - e_grid_desc_mblock_mperblock_nblock_nperblock, - cde_element_op, - block_m_id, - block_n_id); - - // tuple of reference to C/Ds tensor buffers - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, @@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle }, 
Number{}); + // multiple Ds + constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple( + [&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; }, + Number{}); + + constexpr auto ds_thread_buf_size = + d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + auto c01_thread_buf = + make_static_buffer( + Number{}); + + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + return ThreadwiseTensorSliceTransfer_v2< + remove_cvref_t>, + typename ReduceTrait::ReduceAccDataType_, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + remove_cvref_t< + decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>, + Sequence, + Sequence<0, 1, 2, 3>, + 3, + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + }, + Number{}); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // Write E from Vgpr to Vmem + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + typename ReduceTrait::ReduceAccDataType_, + EDataType, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_, + EGlobalMemoryDataOperation, + 1, + 
true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op}; + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); @@ -365,15 +415,6 @@ struct EpilogueReduceCShuffle // make sure it's safe to read from LDS block_sync_lds(); - - // each block loads its C data from LDS, D from global, applies elementwise - // operation and stores result E to global - cde_shuffle_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - { c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_shuffle_block_buf, @@ -381,6 +422,53 @@ struct EpilogueReduceCShuffle make_tuple(I0, I0), c_reduce_thread_buf); + // Note: currently multiple Ds supports only Bias + Add. 
+ // It needs to be generalized for other operations (currently not needed) + if constexpr(NumDTensor > 0) + { + auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0); + // d0 / d1 operations + d0_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], + ds_grid_buf[I0], + ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0], + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + typename ReduceTrait::ReduceAccDataType_ out; + cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1); + + d1_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], + ds_grid_buf[I1], + ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1], + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_function(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + d0_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + } + + // Write E + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + // Reduction static_for<0, NumReduce, 1>{}([&](auto In) { auto& p_reduce_grid = p_reduces_grid[In]; @@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle { constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_global_step); + static_for<0, NumDTensor, 1>{}([&](auto I) { + auto& d_thread_copy_global_to_vgpr = 
ds_thread_copy_global_to_vgpr(I); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step); }); // move on E - cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step); } }); } @@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops; typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops; index_t MRaw; + typename ReduceTrait::D0ElementwiseOperation_ d0_element_op; ReduceGridDesc_M reduce_grid_desc_m; }; diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index b8dd5905aa..dd12cdca8c 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -59,6 +59,8 @@ struct EpilogueCShuffleBase 1, CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma>>; + __device__ static constexpr bool IsLDSNeeded() { return true; } + // *Caution Here repeat is shuffle repeat __device__ static constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp new file mode 100644 index 0000000000..859225a831 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp @@ -0,0 +1,145 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +struct EpilogueDirectStore +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + __device__ static constexpr bool IsLDSNeeded() { return false; } + + template + __device__ static void Run(CThreadBuf& c_thread_buf, + DsGridPointer, + EDataType* p_e_grid, + void*, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + CDEElementwiseOperation& cde_element_op, + const index_t& block_m_id, + const index_t& block_n_id) + { + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // C mapping in single thread. 
+ constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + BlockwiseGemmPipe:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + BlockwiseGemmPipe:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // origin + const auto c_thread_mtx_on_block = + BlockwiseGemmPipe::CalculateCThreadOriginDataIndex(I0, I0); + + const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I0])); + + const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto 
n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(c_thread_mtx_on_block[I1])); + + // E grid descriptor + const auto c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(block_m_id), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_freeze_transform(block_n_id), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<4, 5, 3>{})); + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + CDEElementwiseOperation, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 3, + NRepeat, // VectorSize + EGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(m_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3]), + cde_element_op}; + + c_thread_copy.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + e_grid_buf); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 
cf471578ca..e47bb37a89 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -77,26 +77,79 @@ struct ABTransferWaveTiles static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack); static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma); + template + __host__ __device__ static auto PadGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t MNPad, + index_t sizeK, + index_t KPad, + index_t, + index_t) + { + if constexpr(PadMN && PadK) + { + // pad both MN and K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadMN && !PadK) + { + // pad MN, but not K + return transform_tensor_descriptor( + base_desc, + make_tuple(make_right_pad_transform(sizeMN, MNPad - sizeMN), + make_pass_through_transform(sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(!PadMN && PadK) + { + // pad K, but not MN + return transform_tensor_descriptor( + base_desc, + make_tuple(make_pass_through_transform(sizeMN), + make_right_pad_transform(sizeK, KPad - sizeK)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad MN or K + return base_desc; + } + } + template __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, index_t sizeMN, - index_t, + index_t MNPad, index_t sizeK, - index_t, + index_t KPad, index_t, index_t) { - // Notes: padding is currently not supported - static_assert(!PadMN && !PadK, "padding is currently not supported"); + // Notes: padding is currently not supported with transpose + static_assert(!((PadMN || PadK) && ABDoTranspose), + "padding 
is currently not supported with transpose"); + + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? sizeK : KPad; + + const auto base_desc_padded = + PadGridDescriptor(base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); // Divide the base descriptor MN_K into tiles const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( - base_desc, + base_desc_padded, make_tuple( make_unmerge_transform(make_tuple( - math::integer_divide_ceil(sizeMN, Number{}), Number{})), - make_unmerge_transform(make_tuple(math::integer_divide_ceil(sizeK, Number{}), - Number{}))), + math::integer_divide_ceil(MN_grid, Number{}), Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(K_grid, Number{}), Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); @@ -112,9 +165,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}))), @@ -127,8 +180,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), make_freeze_transform(I0)), @@ -143,9 +196,9 @@ struct ABTransferWaveTiles transform_tensor_descriptor( ab_grid_desc_mntiles_ktiles, make_tuple(make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), + 
math::integer_divide_ceil(MN_grid, Number{})), make_pass_through_transform( - math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(K_grid, Number{})), make_unmerge_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{})), @@ -157,8 +210,8 @@ struct ABTransferWaveTiles ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, make_tuple( make_pass_through_transform( - math::integer_divide_ceil(sizeMN, Number{})), - make_pass_through_transform(math::integer_divide_ceil(sizeK, Number{})), + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), make_pass_through_transform(Number{}), make_freeze_transform(I0), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp new file mode 100644 index 0000000000..bfe5b7bd08 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp @@ -0,0 +1,275 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +template +struct ABTransferWaveTilesInterleave : ABTransferWaveTiles +{ + using Base = ABTransferWaveTiles; + + using Base::ABDoTranspose; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::MNKRow; + + using Base::GetBlockLaneIdx; + using Base::GetBlockStep; + using Base::GetGridLaneIdx; + using Base::GetWaveIdx; + using Base::PadGridDescriptor; + using typename Base::ThisThreadBlock; + + static constexpr auto I4 = Number<4>{}; + + static_assert(!ABDoTranspose, "wave tile interleaved transfer does not support transpose yet"); + + using Base::KRepeat_; + using Base::KWaves_; + using Base::MNRepeat_; + + static constexpr index_t MNWaves_Grid = MNWaves_Gemm; + static constexpr index_t KWaves_Grid = (BlockSize / WaveSize) / MNWaves_Gemm; + static constexpr index_t KRepeat_Grid = KPerBlock / (KWaves_Grid * KPack); + static constexpr index_t MNRepeat_Grid = MNPerBlock / (MNWaves_Grid * MNPerWmma); + + template + __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc, + index_t sizeMN, + index_t MNPad, + index_t sizeK, + index_t KPad, + index_t, + index_t) + { + const auto base_desc_padded = Base::template PadGridDescriptor( + base_desc, sizeMN, MNPad, sizeK, KPad, 0, 0); + + const index_t MN_grid = !PadMN ? sizeMN : MNPad; + const index_t K_grid = !PadK ? 
sizeK : KPad; + + // Divide the base descriptor MN_K into tiles + const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor( + base_desc_padded, + make_tuple(make_unmerge_transform(make_tuple( + math::integer_divide_ceil(MN_grid, Number{}), + Number{})), + make_unmerge_transform(make_tuple( + math::integer_divide_ceil(K_grid, Number{}), Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + // The distinction is needed to get the same global indices for both layouts + // Divide each tile in 2 16x8 subtile + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + // MNKRow = 0-1 + // LaneLocal = 0-15 + // VectorSize must be 8 + if constexpr(!ABDoTranspose) + { + const auto ab_grid_desc_mntiles_ktiles_mnrepeat = transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<3, 2>{}, Sequence<4>{})); + + const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 = + transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_mnrepeat, + make_tuple(make_pass_through_transform(math::integer_divide_ceil( + MN_grid, Number{})), + make_pass_through_transform( + math::integer_divide_ceil(K_grid, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}))), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5>{})); + + // Freeze VectorSize to first element of the loading chunk (for 
convenience) + // Swap MNPerWmma and MNKRow for consistency with transpose descriptor + return transform_tensor_descriptor( + ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1, + make_tuple( + make_pass_through_transform( + math::integer_divide_ceil(MN_grid, Number{})), + make_pass_through_transform(math::integer_divide_ceil(K_grid, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<3>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + } + + __device__ static constexpr auto GetBlockDescriptor() + { + // LDS memory layouts: + // lanes within tiles stored contiguously in chunks of 8 elements + // tiles are then stored first in K dimension + // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize + const auto a_grid_desc_mraw_kraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + I1)); + }(); + + // Freeze VectorSize to first element of the chunk (for convenience) + return transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_freeze_transform(I0)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{})); + } + + template + __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, + BlockDescriptor& block_descriptor, + 
ABElementwiseOperation& ab_element_op, + const index_t block_mn_id, + const index_t) + { + // Note: GlobalBufferNum is currently not used but it will be needed + // once we add other pipelines. It is currently needed only for + // consistency with the thread tiles approach + static_assert(GlobalBufferNum == 1, "single global buffer is only supported"); + constexpr index_t NumABTensor = ABsDataType::Size(); + static_assert(NumABTensor == 1, "multiAB currently not supported"); + + using ABDataType = remove_cvref_t>; + + const auto wave_idx = GetWaveIdx(); + index_t wave_idK = wave_idx[I1]; + index_t wave_idMN = wave_idx[I0]; + + const auto grid_lane_id = Base::template GetGridLaneIdx(); + index_t lane_group_grid = grid_lane_id[I0]; + index_t lane_local_id_grid = grid_lane_id[I1]; + + const auto block_lane_id = GetBlockLaneIdx(); + index_t lane_group_block = block_lane_id[I0]; + index_t lane_local_id_block = block_lane_id[I1]; + + constexpr index_t MNRepeatRatio = MNRepeat_Grid / MNRepeat_; + return ThreadGroupTransferGlobal, + Sequence, + Sequence, + ABK1Value, + ABDoTranspose>( + grid_descriptor[I0], + block_descriptor, + make_multi_index(block_mn_id * MNWaves_Grid + wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_Grid, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_grid, + lane_local_id_grid), + make_multi_index(wave_idMN / MNRepeatRatio, + wave_idK * KRepeat_, + (wave_idMN % MNRepeatRatio) * MNRepeat_, + lane_group_block, + lane_local_id_block), + ab_element_op); + } + + __device__ static constexpr auto GetBlockStep() + { + // Grid descriptor step (MoveSrcSliceWindow) + return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0, I0); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 64f50d13df..c168ca9d18 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { + block_sync_lds(); + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index c3c14edfb8..a1cba118b2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -177,7 +177,8 @@ template + bool ForceThreadTileTransfer = false, + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3 : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -231,7 +232,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer> + ForceThreadTileTransfer, + IsFusedKernel> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -285,7 +287,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 PermuteA, PermuteB, IsBPreShuffled, - ForceThreadTileTransfer>; + ForceThreadTileTransfer, + IsFusedKernel>; using Base::I0; using Base::I1; @@ -334,14 +337,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Problem { __host__ Problem() = default; - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - std::array StrideAs_, - std::array StrideBs_, - std::array StrideDs_, - index_t StrideE_, - index_t KBatch_) + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideE_, + index_t KBatch_) : M{M_}, 
N{N_}, K{K_}, @@ -411,22 +414,22 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Argument : public tensor_operation::device::BaseArgument, public Problem { __host__ Argument() = default; - __host__ Argument(std::array p_as_grid_, - std::array p_bs_grid_, - std::array p_ds_grid_, - EDataType* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - std::array StrideAs_, - std::array StrideBs_, - std::array StrideDs_, - index_t StrideE_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CDEElementwiseOperation cde_element_op_, - bool is_reduce_ = false) + __host__ __device__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, + std::array p_ds_grid_, + EDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideE_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_, + bool is_reduce_ = false) : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_}, p_as_grid{}, p_bs_grid{}, @@ -604,6 +607,67 @@ struct GridwiseGemm_wmma_cshuffle_v3 MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args, + A_k_id, + B_k_id); + } + + // Overload to pass in custom As/Bs/Ds/E grid descriptors + // Used for contraction operations, where tensor transforms are non-trivial + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const AsGridDescriptor_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + 
const BsGridDescriptor_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args, + const index_t A_k_id = 0, + const index_t B_k_id = 0) + { + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -770,9 +834,13 @@ struct GridwiseGemm_wmma_cshuffle_v3 B_k_id); } - __device__ static auto DefaultBlock2CTileMap(const Problem& problem) + __device__ __host__ static auto DefaultBlock2CTileMap(const Problem& problem) { - return Block2CTileMap{problem.M, problem.N, 4}; + return DefaultBlock2CTileMap(problem.M, problem.N); + } + __device__ __host__ static auto DefaultBlock2CTileMap(const index_t M, const index_t N) + { + return Block2CTileMap{M, N, 4}; } // Run method for convolution for bwd_data (grid descriptors are passed as arguments, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 11e9a6dbf7..b7b88d4920 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" #include 
"ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" @@ -24,6 +25,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" @@ -50,13 +52,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) std::is_same_v))) { #endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = + typename std::conditional::type; + + constexpr index_t LDS_size = + GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run( p_shared, splitk_batch_offset, karg, epilogue_args); @@ -167,7 +175,8 @@ template // only needed for convolution (limitation) + bool ForceThreadTileTransfer = false, // only needed for convolution (limitation) + bool IsFusedKernel = false> struct GridwiseGemm_wmma_cshuffle_v3_base { @@ -182,6 +191,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); using LDSTypeA = typename std::conditional<(NumATensor > 1), @@ -232,30 +242,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base 
return 1; }(); + static constexpr index_t WaveSize = + WmmaSelector::selected_wmma + .wave_size; + // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default - // - pipeline v1 because v3 is buggy (fixed in batched gemm gemm implementation) - // AK1Value == 8 is not really a limitation but a requirement for the method so - // it will stay + // - GemmSpecialization Default with transpose #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; static constexpr bool IsBWaveTransferApplicable = !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - GemmSpec == tensor_operation::device::GemmSpecialization::Default && + ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && + !is_same_v) || + is_same_v) && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + + static constexpr bool IsWaveTileInterleavedFitting = + (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); + + // We need to investigate if it makes sense to remove cshuffle for smaller types + // Currently we use direct store for NRepeat equal to 4 or 8. 
For 16 bit type we use at + // least buffer store 64 bit for 16 contiguous threads -> 128 bytes in total (full cache line) + static constexpr bool UseDirectStore = is_same_v && + sizeof(ComputeTypeB) == 2 && sizeof(EDataType) == 2 && + NumDTensor == 0 && (NRepeat == 4 || NRepeat == 8) && + !IsFusedKernel && IsWaveTileInterleavedFitting; #else static constexpr bool IsAWaveTransferApplicable = false; static constexpr bool IsBWaveTransferApplicable = false; + static constexpr bool UseDirectStore = false; #endif - static constexpr index_t WaveSize = - WmmaSelector::selected_wmma - .wave_size; static constexpr bool UseBlockPaddingA = ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; using ATransfer = typename std::conditional< @@ -293,7 +317,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr bool UseBlockPaddingB = BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4; - using BTransfer = typename std::conditional< IsBPreShuffled, ABTransferThreadTilesPreShuffle, typename std::conditional< IsBWaveTransferApplicable, - ABTransferWaveTiles, + typename std::conditional< + UseDirectStore, + ABTransferWaveTilesInterleave, + ABTransferWaveTiles>::type, ABTransferThreadTiles __host__ __device__ static auto - MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, + const index_t M, const index_t MPad, const index_t K, const index_t KPad, @@ -481,16 +520,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base GemmSpec == GemmSpecialization::NKPadding; return generate_tuple( [&](auto i) { - const auto base_desc = MakeAGridDescriptor_M_K(M, K, StrideAs[i]); - return ATransfer::template MakeGridDescriptor( - base_desc, M, MPad, K, KPad, StrideAs[i], AK0); + base_descs[i], M, MPad, K, KPad, StrideAs[i], AK0); }, Number{}); } + template + __device__ static auto MakeAGridDescriptor_AK0_M_AK1(const GridDescBase& base_desc) + { + const auto M = base_desc.GetLength(I0); + const auto 
K = base_desc.GetLength(I1); + + const auto AK0 = K / AK1Value; + + constexpr bool padM = false; + constexpr bool padK = false; + return ATransfer::template MakeGridDescriptor(base_desc, M, M, K, K, 0, AK0); + } + + template __host__ __device__ static auto - MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + MakeAsGridDescriptor_AK0_M_AK1(const BaseDescriptors_M_K& base_descs, const index_t KBatch = 1) + { + const index_t M = base_descs.At(I0).GetLength(I0); + const index_t K = base_descs.At(I0).GetLength(I1); + + const index_t MPad = CalculateMPadded(M); + const index_t KPad = CalculateKPadded(K, KBatch); + + const index_t AK0 = CalculateAK0Padded(K, KBatch); + + return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, {}, AK0); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + const auto base_descs = + generate_tuple([&](auto i) { return MakeAGridDescriptor_M_K(M, K, StrideAs[i]); }, + Number{}); + return MakeAsGridDescriptor_AK0_M_AK1(base_descs, M, MPad, K, KPad, StrideAs, AK0); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs, + const index_t K, const index_t KPad, const index_t N, const index_t NPad, @@ -508,13 +589,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base GemmSpec == GemmSpecialization::MKPadding; return generate_tuple( [&](auto i) { - const auto base_desc = MakeBGridDescriptor_N_K(N, K, StrideBs[i]); return BTransfer::template MakeGridDescriptor( - base_desc, N, NPad, K, KPad, StrideBs[i], BK0); + base_descs[i], N, NPad, K, KPad, StrideBs[i], BK0); }, Number{}); } + template + __device__ static auto MakeBGridDescriptor_BK0_N_BK1(const GridDescBase& base_desc) + { + const auto N = base_desc.GetLength(I0); + const auto K = base_desc.GetLength(I1); + + const auto BK0 = K / BK1Value; + + constexpr bool padN = false; 
+ constexpr bool padK = false; + return BTransfer::template MakeGridDescriptor(base_desc, N, N, K, K, 0, BK0); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const BaseDescriptors_N_K& base_descs, const index_t KBatch = 1) + { + const index_t N = base_descs.At(I0).GetLength(I0); + const index_t K = base_descs.At(I0).GetLength(I1); + + const index_t NPad = CalculateNPadded(N); + const index_t KPad = CalculateKPadded(K, KBatch); + + const index_t BK0 = CalculateBK0Padded(K, KBatch); + + return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, {}, BK0); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + + const auto base_descs = + generate_tuple([&](auto i) { return MakeBGridDescriptor_N_K(N, K, StrideBs[i]); }, + Number{}); + return MakeBsGridDescriptor_BK0_N_BK1(base_descs, K, KPad, N, NPad, StrideBs, BK0); + } + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor() { constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); @@ -593,8 +716,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base #endif } - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MakeDsGridPointer() { return generate_tuple( @@ -620,7 +741,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } template - __device__ __host__ static constexpr auto + __host__ __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) @@ -678,6 +799,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base ThisThreadBlock, BlockwiseGemmPipe>; + using EpilogueDirectStore = EpilogueDirectStore; + using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle< DsDataType, EDataType, @@ -963,14 +1092,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return true; } - __host__ static constexpr bool 
CalculateHasMainKBlockLoop(index_t K) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) { const index_t num_loop = K / KPerBlock; @@ -999,18 +1128,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base max_lds_align) : 0; - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - EpilogueType:: - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + if constexpr(EpilogueType::IsLDSNeeded()) + { + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + EpilogueType:: + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + - b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize; + } } template @@ -1147,7 +1284,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base num_k_block_main_loop, num_k_block_per_scale); - // shuffle C and write out + // Epilogue: + // - CShuffle / direct store + // - 
Multiple Ds + // - Fused operations epilogue_args.template Run( c_thread_buf, p_ds_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 9339916d6f..8188c42ca5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -663,7 +663,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetHack = false> __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -673,12 +674,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0) + const index_t k_id = 0, + const index_t k_batch = 1) { + const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetHack ? 
k_batch : 1; + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -744,7 +749,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -775,7 +780,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1024,7 +1029,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetHack = false> __device__ static void Run_2Lds(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -1035,12 +1041,16 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0) + const index_t k_id = 0, + const index_t k_batch = 1) { + const long_index_t a_space_size_divisor = SplitKOffsetHack ? 
k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1106,7 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1137,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetHack ? 
0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 6fd6529fbb..e6f055d183 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -149,7 +150,8 @@ template + bool HasMainKBlockLoop, + bool SplitKOffsetHack> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -164,7 +166,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + const CBlockClusterAdaptor c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + index_t k_batch) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) @@ -172,17 +177,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + 
a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor, + split_k_stride_a, + split_k_stride_b, + k_batch); } #else ignore = p_a_grid; @@ -195,6 +204,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = b_element_op; ignore = c_element_op; ignore = c_block_cluster_adaptor; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = k_batch; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -536,7 +548,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight MRepeat, NRepeat, FloatC, - CGlobalMemoryDataOperation>(); + CGlobalMemoryDataOperation_>(); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template @@ -646,6 +658,416 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + index_t k_batch) + { + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + // Use 
compile-time branching based on template parameters + const long_index_t split_k_offset_a = SplitKOffsetHack ? k_batch_id * split_k_stride_a : 0; + const long_index_t split_k_offset_b = SplitKOffsetHack ? k_batch_id * split_k_stride_b : 0; + + // When hack is enabled, buffer size equals the stride (calculated from descriptor's + // CalculateOffset method in the device layer). This properly accounts for the + // descriptor's transform pipeline and non-compact strides. + // When hack is disabled, use the full element space size. + const long_index_t a_buffer_size = + SplitKOffsetHack ? split_k_stride_a : a_b_k0_m_k1_grid_desc.GetElementSpaceSize(); + + const long_index_t b_buffer_size = + SplitKOffsetHack ? split_k_stride_b : b_b_k0_n_k1_grid_desc.GetElementSpaceSize(); + + ignore = k_batch; // k_batch value itself not used in this function + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid + split_k_offset_a, a_buffer_size); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + split_k_offset_b, b_buffer_size); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = 
GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + FloatAAdjusted, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(SplitKOffsetHack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + FloatBAdjusted, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(SplitKOffsetHack ? 
0 : k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + K1 <= 4) || + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global 
memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXdl, ""); + static_assert(N2 == NPerXdl, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 
4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, + 1, + 
CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } + template __device__ static void Run(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index c556dbec10..3b98798833 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -173,7 +173,8 @@ template + typename LDSTypeB = BDataType, + bool NonTemporalLoadB = false> struct GridwiseMoeGemmBlockScale { using AScaleType = float; 
@@ -1202,6 +1203,13 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { +#if defined(__gfx942__) || defined(__gfx950__) + constexpr auto b_coherence_flag = NonTemporalLoadB + ? AmdBufferCoherenceEnum::WAVE_NT1 + : AmdBufferCoherenceEnum::DefaultCoherence; +#else + constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence; +#endif ignore = b_element_op; index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); @@ -1300,15 +1308,16 @@ struct GridwiseMoeGemmBlockScale const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1465,9 +1474,11 @@ struct GridwiseMoeGemmBlockScale if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; - const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto b_grid_buf_up = + make_dynamic_buffer( + p_b_grid_up + + expert_id * static_cast(expert_stride) / BPackedSize, + 
b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType, @@ -1485,9 +1496,10 @@ struct GridwiseMoeGemmBlockScale KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf_up = + make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -2227,9 +2247,11 @@ struct GridwiseMoeGemmBlockScale if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; - const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto b_grid_buf_up = + make_dynamic_buffer( + p_b_grid_up + + expert_id * 
static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType, @@ -2247,9 +2269,10 @@ struct GridwiseMoeGemmBlockScale KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf_up = + make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) // Deprecated parameter for backward compatibility { using namespace ck; @@ -172,7 +173,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -190,7 +192,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -208,7 +210,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -246,7 +248,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -285,7 +287,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), 
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -323,7 +325,8 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) { using namespace ck; @@ -359,7 +362,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -378,7 +382,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -393,7 +397,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -422,7 +426,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( 
out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -463,7 +467,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -497,7 +501,8 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false) { using namespace ck; @@ -540,7 +545,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchDim = split_k_offset_hack ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -559,7 +565,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -574,7 +580,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -603,7 +609,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -653,7 +659,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim, GemmK0, GemmK1Number)), 
make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 96482b1412..94eae555e9 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -353,7 +355,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); @@ -373,7 +378,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -389,7 +394,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -419,7 +424,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -460,7 +465,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), 
make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -495,7 +500,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -531,7 +538,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -551,7 +561,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -567,7 +577,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), 
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -597,7 +607,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -647,7 +657,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -681,7 +691,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -724,7 +736,10 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchDim = (split_k_offset_hack && !use_full_batch_kindex) ? 
1 : GemmKBatch; + const index_t GemmKPad = KBatchDim * GemmK0 * GemmK1Number; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -744,7 +759,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -760,7 +775,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -790,7 +805,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -855,7 +870,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchDim * GemmK0, GemmK1Number)), 
make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 35389bda37..057687985d 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -10,7 +10,8 @@ namespace ck { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index b76d957044..07388c4847 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -111,6 +111,101 @@ __device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) return vy.template AsType()[I0]; } +#if defined(__gfx11__) +template <> +__device__ float8_t atomic_add(float8_t* p_dst, const float8_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + vy.template AsType()(I2) = + atomicAdd(c_style_pointer_cast(p_dst) + 2, vx.template AsType()[I2]); + vy.template AsType()(I3) = + atomicAdd(c_style_pointer_cast(p_dst) + 3, vx.template AsType()[I3]); + vy.template AsType()(I4) = + atomicAdd(c_style_pointer_cast(p_dst) + 4, vx.template AsType()[I4]); 
+ vy.template AsType()(I5) = + atomicAdd(c_style_pointer_cast(p_dst) + 5, vx.template AsType()[I5]); + vy.template AsType()(I6) = + atomicAdd(c_style_pointer_cast(p_dst) + 6, vx.template AsType()[I6]); + vy.template AsType()(I7) = + atomicAdd(c_style_pointer_cast(p_dst) + 7, vx.template AsType()[I7]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ half4_t atomic_add(half4_t* p_dst, const half4_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ half8_t atomic_add(half8_t* p_dst, const half8_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomic_add(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = atomic_add(c_style_pointer_cast(p_dst) + 1, + vx.template AsType()[I1]); + vy.template AsType()(I2) = atomic_add(c_style_pointer_cast(p_dst) + 2, + vx.template AsType()[I2]); + vy.template AsType()(I3) = atomic_add(c_style_pointer_cast(p_dst) + 3, + vx.template AsType()[I3]); + vy.template AsType()(I4) = atomic_add(c_style_pointer_cast(p_dst) + 4, + vx.template AsType()[I4]); + vy.template 
AsType()(I5) = atomic_add(c_style_pointer_cast(p_dst) + 5, + vx.template AsType()[I5]); + vy.template AsType()(I6) = atomic_add(c_style_pointer_cast(p_dst) + 6, + vx.template AsType()[I6]); + vy.template AsType()(I7) = atomic_add(c_style_pointer_cast(p_dst) + 7, + vx.template AsType()[I7]); + + return vy.template AsType()[I0]; +} +#endif // defined(__gfx11__) + // Caution: DO NOT REMOVE // intentionally have only declaration but no definition to cause compilation failure when trying to // instantiate this template. The purpose is to make the implementation of atomic_max explicit for diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 78931407d8..1657595030 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -7,6 +7,7 @@ #include "ck/utility/sequence.hpp" #include "ck/utility/type.hpp" #include "ck/utility/enable_if.hpp" +#include namespace ck { @@ -220,4 +221,49 @@ constexpr Tuple tie(Args&... args) noexcept return {args...}; } +// +// tuple_map: Map tuple with a different type +// e.g. tuple_map<Wrapper, Tuple<T1, T2, T3>> becomes Tuple<Wrapper<T1>, Wrapper<T2>, Wrapper<T3>> +// +template