diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a4019188e..e2f3c46619 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,7 @@ option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +option(CK_EXPERIMENTAL_GEMM_BENCHMARK "Enable experimental gemm benchmark for gfx1250" OFF) option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) option(BUILD_CK_TILE_ENGINE "Build the tile_engine subdirectory" OFF) @@ -303,6 +304,9 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10") add_definitions(-DCK_GFX1030_SUPPORT) endif() +# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA +set(CK_TILE_USE_WMMA 0) + if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) @@ -321,6 +325,8 @@ endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") + add_definitions(-DCK_TILE_USE_OCP_FP8) + set(CK_TILE_USE_OCP_FP8 "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94") add_definitions(-DCK_USE_FNUZ_FP8) @@ -332,6 +338,16 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_GFX950_SUPPORT) set(CK_GFX950_SUPPORT "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx1250") + add_definitions(-DCK_USE_GFX1250) + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") + add_definitions(-DCK_GFX1250_SUPPORT) + set(CK_GFX1250_SUPPORT 
"ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + add_definitions(-DCK_GFX12_SUPPORT) +endif() if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32) add_definitions(-DCK_ENABLE_TF32) @@ -413,6 +429,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) +option(CK_TEST_DISABLE_GPU_VALIDATION "Whether to disable GPU validation in CK tests." OFF ) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -431,6 +448,10 @@ if (ENABLE_JSON_DUMP) message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") endif() +if (CK_TEST_DISABLE_GPU_VALIDATION) + add_compile_definitions(CK_TEST_DISABLE_GPU_VALIDATION) +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) @@ -797,6 +818,10 @@ if(BUILD_CK_PROFILER) endif() endif() +if (CK_EXPERIMENTAL_GEMM_BENCHMARK) + add_subdirectory(experimental/gemm_benchmark) +endif() + if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() diff --git a/Jenkinsfile b/Jenkinsfile index f78bc8e329..d2b027e7af 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -772,6 +772,26 @@ def cmake_build(Map conf=[:]){ try { //build CK sh cmd + if (runAllUnitTests){ + // Archive artifacts if they were generated + if (fileExists("ck_build_trace_${arch_name}.json")) { + archiveArtifacts "ck_build_trace_${arch_name}.json" + } + if (fileExists("clang_build_analysis_${arch_name}.log")) { + archiveArtifacts "clang_build_analysis_${arch_name}.log" + } + // Process ninja build trace after full build + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json" + archiveArtifacts 
"ck_build_trace_${arch_name}.json" + sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json" + + if (params.NINJA_FTIME_TRACE) { + echo "running ClangBuildAnalyzer" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log" + archiveArtifacts "clang_build_analysis_${arch_name}.log" + } + } } catch (Exception buildError) { echo "Build failed: ${buildError.getMessage()}" throw buildError @@ -796,49 +816,32 @@ def cmake_build(Map conf=[:]){ } } - //run tests except when NO_CK_BUILD is set + //run tests except when NO_CK_BUILD is set and except on gfx1250 if(!setup_args.contains("NO_CK_BUILD")){ - if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){ - // do not run unit tests when building instances only - if(!params.BUILD_INSTANCES_ONLY){ - if (!runAllUnitTests){ - // Smart Build: Run smart_build_and_test.sh - sh """ - export WORKSPACE_ROOT=${env.WORKSPACE} - export PARALLEL=32 - export NINJA_JOBS=${nt} - export ARCH_NAME=${arch_name} - export PROCESS_NINJA_TRACE=true - export NINJA_FTIME_TRACE=${params.NINJA_FTIME_TRACE ? 
'true' : 'false'} - bash ../script/dependency-parser/smart_build_and_test.sh - """ - - // Archive artifacts if they were generated - if (fileExists("ck_build_trace_${arch_name}.json")) { - archiveArtifacts "ck_build_trace_${arch_name}.json" - } - if (fileExists("clang_build_analysis_${arch_name}.log")) { - archiveArtifacts "clang_build_analysis_${arch_name}.log" - } - } - else{ + // run unit tests unless building library for all targets + // Note: NINJA_BUILD_TRACE is no longer checked here; ninja trace processing runs right after the build step (only when runAllUnitTests is set) + // So no ninja trace processing needed here + if (!params.BUILD_INSTANCES_ONLY){ + if (!runAllUnitTests && !setup_args.contains("gfx1250") ){ + // Smart Build: Run smart_build_and_test.sh + sh """ + export WORKSPACE_ROOT=${env.WORKSPACE} + export PARALLEL=32 + export NINJA_JOBS=${nt} + export ARCH_NAME=${arch_name} + export PROCESS_NINJA_TRACE=false + export NINJA_FTIME_TRACE=false + bash ../script/dependency-parser/smart_build_and_test.sh + """ + } + else{ //run all tests + if(!setup_args.contains("gfx1250")){ echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)" sh "ninja -j${nt} check" - - // Process ninja build trace after full build - sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json" - archiveArtifacts "ck_build_trace_${arch_name}.json" - sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json" - - if (params.NINJA_FTIME_TRACE) { - echo "running ClangBuildAnalyzer" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . 
clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log" - archiveArtifacts "clang_build_analysis_${arch_name}.log" - } } - if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { - sh 'ninja check-builder' + else{ //do not run tests on gfx1250, just build everything + echo "Building for gfx1250" + sh "ninja -j${nt}" } if (params.RUN_ROCM_CK_TESTS) { sh 'ninja check-rocm-ck' @@ -850,47 +853,8 @@ def cmake_build(Map conf=[:]){ stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" } } - if(params.BUILD_INSTANCES_ONLY){ - // build deb packages - echo "Build library package" - sh 'ninja -j64 package' - sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb' - stash includes: "composablekernel-dev**.deb", name: "lib_package" - } - } - else{ - // run unit tests unless building library for all targets - // Note: This else block is when NINJA_BUILD_TRACE=false and BUILD_INSTANCES_ONLY=false - // So no ninja trace processing needed here - if (!params.BUILD_INSTANCES_ONLY){ - if (!runAllUnitTests){ - // Smart Build: Run smart_build_and_test.sh - sh """ - export WORKSPACE_ROOT=${env.WORKSPACE} - export PARALLEL=32 - export NINJA_JOBS=${nt} - export ARCH_NAME=${arch_name} - export PROCESS_NINJA_TRACE=false - export NINJA_FTIME_TRACE=false - bash ../script/dependency-parser/smart_build_and_test.sh - """ - } - else{ - echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)" - sh "ninja -j${nt} check" - } - if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { - sh 'ninja check-builder' - } - if (params.RUN_ROCM_CK_TESTS) { - sh 'ninja check-rocm-ck' - } - if(params.BUILD_PACKAGES){ - echo "Build ckProfiler packages" - sh 'ninja -j64 package' - 
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" - } + if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) { + sh 'ninja check-builder' } } } @@ -1414,6 +1378,10 @@ pipeline { name: "BUILD_GFX12", defaultValue: true, description: "Build CK and run tests on gfx12 (default: ON)") + booleanParam( + name: "BUILD_GFX1250", + defaultValue: true, + description: "Build CK for gfx1250 (default: ON)") booleanParam( name: "NINJA_BUILD_TRACE", defaultValue: true, @@ -2052,6 +2020,7 @@ pipeline { cleanWs() } } + /* stage("Build CK and run Tests on gfx1010") { when { @@ -2068,6 +2037,7 @@ pipeline { cleanWs() } } + */ stage("Build CK and run Tests on gfx1030") { when { @@ -2116,6 +2086,21 @@ pipeline { cleanWs() } } + stage("Build CK for gfx1250") + { + when { + beforeAgent true + expression { params.BUILD_GFX1250.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1250" -DDISABLE_DL_KERNELS="ON" """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:npi-mi450-latest", config_targets: "install", no_reboot:true, build_type: 'Release', prefixpath: '/usr/local') + cleanWs() + } + } } post { always { @@ -2194,3 +2179,4 @@ pipeline { } } } + diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index aba462638e..bc2e6a78e7 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -47,7 +47,7 @@ list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllv example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) example_compile_options(example_gemm_xdl_bf16_v3 
PRIVATE ${GEMM_OPTIONS}) -list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic) +list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -82,7 +82,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx942 gfx950) +list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -95,7 +95,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic) +list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -141,6 +141,9 @@ add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.c add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3) add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3) +add_example_executable(example_gemm_wmma_fp8_v3_reg_spill gemm_wmma_fp8_v3_reg_spill.cpp) +example_compile_options(example_gemm_wmma_fp8_v3_reg_spill PRIVATE "SHELL: -Rpass-analysis=kernel-resource-usage ") +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3_reg_spill) add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3) add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp) @@ -149,6 +152,11 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale 
gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) + +if(SUPPORTED_GPU_TARGETS MATCHES "gfx125") + add_example_executable(example_gemm_xdl_bf16_v3_prefetch gemm_xdl_bf16_v3_prefetch.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3_prefetch) +endif() add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle) add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp) diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp index f9bfdbee13..715b454fa0 100644 --- a/example/01_gemm/gemm_wmma_fp8_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -29,7 +29,7 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, - 128, 64, 64, + 128, 64, 128, 16, 16, // AK1, BK1 16, 16, 4, 2, @@ -42,7 +42,6 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf ComputeTypeA, ComputeTypeB>; // clang-format on -using ReferenceComputeType = ck::f8_t; using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + ComputeTypeA, + ComputeTypeB>; #include "run_gemm_example_v2.inc" diff --git a/example/01_gemm/gemm_wmma_fp8_v3_reg_spill.cpp b/example/01_gemm/gemm_wmma_fp8_v3_reg_spill.cpp new file mode 100644 index 0000000000..5a2778986d --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_v3_reg_spill.cpp @@ -0,0 +1,113 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * \brief Example of GEMM using WMMA that illustrates register spilling on gfx1200 architecture but + * no register spilling on gfx1250. 
+ * + * This example demonstrates how more registers available on the gfx1250 architecture can help avoid + * register spilling that occurs on gfx1200 when using a specific GEMM configuration. + * + * This example must be compiled with the following flag to see the resource allocations: + * "-Rpass-analysis=kernel-resource-usage" + * + * On gfx1200, the kernel will show register spilling due to limited VGPRs: + * \verbatim + * TotalSGPRs: 105 + * VGPRs: 256 + * ScratchSize [bytes/lane]: 56 + * Dynamic Stack: False + * Occupancy [waves/SIMD]: 5 + * SGPRs Spill: 0 + * VGPRs Spill: 15 + * LDS Size [bytes/block]: 32768 + * + * gfx1201 - AMD Radeon RX 9070 XT + * Problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:3840, NP:4096, KRead:4096, + * KP:4096, AK0:512, BK0:512, MBlock: 30, NBlock: 32} + * + * Perf: 0.882764 ms, 145.961 TFlops, 72.4578 GB/s, DeviceGemm_Wmma_CShuffleV3 + * BlkSize: 128, BlkTile: 128x128x128, WaveTile: 16x16, WaveMap: 4x4, VmemReadVec: 8x8, + * BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, BlkGemmPipelinePrefetchStages: + * 1, KPack: 16 + * \endverbatim + * + * On gfx1250, the same kernel will not show register spilling due to increased VGPRs: + * \verbatim + * TotalSGPRs: 32 + * VGPRs: 318 + * ScratchSize [bytes/lane]: 0 + * Dynamic Stack: False + * Occupancy [waves/SIMD]: 3 + * SGPRs Spill: 0 + * VGPRs Spill: 0 + * LDS Size [bytes/block]: 32768 + * \endverbatim + * + * \note The register allocations above can be influenced by compiler version and code + * changes/optimizations. 
+ */ + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, //blocksize + 128, 128, 128, // M/N/KPerBlock + 8, 8, // AK1, BK1 + 16, 16, //MPerWmma, NPerWmma + 4, 4, //MRepeat, NRepeat + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,// + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, + ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel support gfx12 only" << std::endl; + + return 0; + } + return !run_gemm_splitk_example(argc, argv); +} diff --git a/example/01_gemm/gemm_xdl_bf16_v3_prefetch.cpp b/example/01_gemm/gemm_xdl_bf16_v3_prefetch.cpp new file mode 100644 index 0000000000..e29ebeaadb --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_v3_prefetch.cpp @@ -0,0 +1,325 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if 1 +static const uint32_t AB_K1 = 8; + +// clang-format off +template +using DeviceGemmV3Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 256, 256, + 64, AB_K1, AB_K1, + 16, 16, + 8, 16, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, AB_K1, AB_K1, 0, + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, AB_K1, AB_K1, 0, + 2, 4, S<1, 8, 1, 16>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, + CDataType, + CDataType, + false, + false, + 0, + UseDataCachePrefetch>; +// clang-format on +#else +// prefetch is faster on these params +// clang-format off +template +using DeviceGemmV3Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, + 64, 8, 8, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, + CDataType, + CDataType, + false, + false, + 0, + UseDataCachePrefetch>; +// clang-format on +#endif + +using ReferenceGemmInstance = 
ck::tensor_operation::host:: + ReferenceGemm; + +template +std::pair run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = GemmInstanceType{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return std::make_pair(true, ave_time); + } + + bool pass = true; + if((config.do_verification == 1) || (config.do_verification == 3)) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= 
ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return std::make_pair(pass, ave_time); +} + +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config, + bool& compareWithNonDataCachePrefetchImpl) +{ + compareWithNonDataCachePrefetchImpl = false; + + if(argc == 1) + { + // use default case + } + else if(argc == 4 || argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + if(argc >= 10) + { + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.KBatch = std::stoi(argv[10]); + if(argc >= 12) // argv[11] (arg11) exists as soon as argc reaches 12 + { + compareWithNonDataCachePrefetchImpl = std::stoi(argv[11]); + } + } + } + } + else + { + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: KBatch" << std::endl + << "arg11: 
compareWithNonDataCachePrefetchImpl (0=no, 1=yes)" << std::endl; + return false; + } + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + bool compareWithNonDataCachePrefetchImpl; + + if(!parse_cmd_args(argc, argv, problem_size, config, compareWithNonDataCachePrefetchImpl)) + { + return 1; + } + + auto [pass, ave_time] = run_gemm>(problem_size, config); + + if(compareWithNonDataCachePrefetchImpl) + { + auto [pass2, ave_time2] = run_gemm>(problem_size, config); + std::cout << "DataCache Prefetching enabled ave_time: " << ave_time << " ms" << std::endl; + std::cout << "DataCache Prefetching disabled ave_time: " << ave_time2 << " ms" << std::endl; + float speedup = ave_time2 / ave_time; + std::cout << "On average kernel with DataCache prefetching is " << speedup + << " times faster than without DataCache prefetching." << std::endl; + + if(speedup < 1.0f) + std::cout << "WARNING: Kernel with DataCache prefetching is slower!" << std::endl; + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp index 3c71c59d4c..e7c0061074 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp @@ -28,10 +28,10 @@ using DeviceGemmV2_Streamk_Instance = ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 64, - 16, 16, + 32, 32, 256, 8, 16, 16, 16, - 1, 1, + 2, 2, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, diff --git a/example/01_gemm/gemm_xdl_fp32_v3.cpp b/example/01_gemm/gemm_xdl_fp32_v3.cpp new file mode 100644 index 0000000000..c4066063e8 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp32_v3.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#define CK_ENABLE_DYNAMIC_WARP_SIZE 1 +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = float; +using BDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 64, 64, + 64, 4, 4, + 16, 16, + 2, 4, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 2, 2, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 2, 2, 0, + 1, 2, S<1, 32, 1, 4>, 2, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64_v3.cpp b/example/01_gemm/gemm_xdl_fp64_v3.cpp new file mode 100644 index 0000000000..061508b38c --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp64_v3.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#define CK_ENABLE_DYNAMIC_WARP_SIZE 1 +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = double; +using BDataType = double; +using AccDataType = double; +using CShuffleDataType = double; +using CDataType = double; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 64, 64, + 64, 4, 4, + 16, 16, + 2, 4, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 2, 2, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 2, 2, 0, + 1, 2, S<1, 32, 1, 4>, 2, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 37b17da3cf..6a5b3f11d6 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -29,8 +29,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| buffer| SrcDstVectorDim| DstScalar| //###########| | | | | | | | Operation| Operation| Operation| | | | | 
| | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| size | | PerVector| //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if 0 - < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>; +#if 0 + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>; using ADataType = ck::half_t; using BDataType = ck::half_t; using CDataType = ck::half_t; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 9b48d5765d..fe81bdb355 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -181,7 +181,7 @@ int main(int argc, char* argv[]) exit(0); } - bool is_supported = ck::is_gfx11_supported(); + bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported(); if(!is_supported) { std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index a770bf5c77..aeeddca0d8 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -181,7 +181,7 @@ int main(int argc, char* argv[]) exit(0); } - bool is_supported = ck::is_gfx11_supported(); + bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported(); if(!is_supported) { std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index c8234bd3b3..19f04f0b95 100644 --- 
a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -19,7 +19,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx942 gfx950) +list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index ea0168accc..9abb91b58a 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp index 11b7cfc78e..59004fe8ef 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp @@ -49,9 +49,9 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // KPerBlock + 32, // AK1 + 32, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp index c4b84e47b1..151eb34837 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp @@ -50,9 +50,9 @@ using 
DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // KPerBlock + 32, // AK1 + 32, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index fab188fef1..4041a6b434 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp index baf6c46987..a2fae971e7 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp @@ -48,14 +48,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // NPerBlock + 128, // KPerBlock + 32, // AK1 + 32, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index ab3883e5b3..9d8c166668 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -51,10 +51,10 @@ using 
DeviceGroupedConvNDFwdInstance = 16, // KPerBlock 4, // AK1 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -72,13 +72,13 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 16, 1, 16>, - 4>; + 2>; #include "run_convnd_fwd_example.inc" int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp index bab1ddad9f..bb82c7ce77 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -51,9 +51,9 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // KPerBlock + 32, // AK1 + 32, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp index 79d6531956..0d4e0cd278 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -50,9 +50,9 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 128, // KPerBlock + 32, // AK1 + 32, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp index 69a94c3b89..020917048a 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp +++ 
b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -17,7 +17,7 @@ using RsDataType = ck::Tuple; int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index a425116048..8a7566e1f1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -25,6 +25,7 @@ static constexpr auto ConvSpec = static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto KPerBlock = sizeof(ADataType) == 1 ? 64 : 32; // clang-format off template using DeviceInstance = @@ -34,9 +35,9 @@ using DeviceInstance = //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| Specialization| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | #ifdef BUILD_INT4_EXAMPLE - < NDimSpatial, ALayout, BLayout, DELayout, RLayout, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; + < NDimSpatial, ALayout, BLayout, 
DELayout, RLayout, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; #else - < NDimSpatial, ALayout, BLayout, DELayout, RLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; + < NDimSpatial, ALayout, BLayout, DELayout, RLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; #endif template @@ -292,15 +293,15 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, conv_output_device_buf.FromDevice(conv_output_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data()); - + auto rtol = ck::is_same_v ? 1e-1f : 1e-3f; auto pass = ck::utils::check_err(conv_output_device, conv_output_host, "Error: incorrect results! (Matrix E)", - 1e-3f, + rtol, 1e-3f); pass = pass && ck::utils::check_err( - r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f); + r0_device, r0_host, "Error: incorrect results! 
(Matrix R0)", rtol, 1e-3f); if(pass) std::cout << "Verification on CPU: PASS" << std::endl; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index 21802808a4..5ff3d772ad 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -328,7 +328,8 @@ int main(int argc, char* argv[]) problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; - if(argc == 5) + if(argc == 1) {} + else if(argc == 5) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 2e7748d3f8..278e0936ac 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on struct ProblemSize final @@ -302,7 +302,8 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - if(argc == 5) + if(argc == 1) {} + else if(argc == 5) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index b977956690..226d8483da 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on struct ProblemSize final @@ -287,7 +287,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_element_op); ref_invoker.Run(ref_argument); - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); } } @@ -302,7 +301,8 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - if(argc == 5) + if(argc == 1) {} + else if(argc == 5) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 8d346171c8..d911d54a93 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -59,7 +59,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on #include "run_grouped_gemm_example.inc" @@ -71,7 +71,8 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - if(argc == 4) + if(argc == 1) {} + else if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index f32d5e9f6d..bf9d50a22f 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -69,10 +69,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 16, // KPerBlock 4, // AK1 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // 
MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -89,7 +89,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on @@ -121,25 +121,21 @@ int main(int argc, char* argv[]) { // do nothing } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideE = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } } else { @@ -147,10 +143,10 @@ int main(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: Measure kernel execution time (1=ON, 0=Off)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); - exit(0); + exit(1); } - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp 
b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index 6b5dde3cc7..09ba0f5aaf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -76,10 +76,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 16, // KPerBlock 4, // AK1 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -96,7 +96,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on @@ -127,25 +127,21 @@ int main(int argc, char* argv[]) { // do nothing } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideE = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } } else { @@ -154,10 +150,10 @@ int main(int argc, char* 
argv[]) << " arg3: Measure kernel execution time (1=ON, 0=Off)\n" << " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n" << std::endl; - exit(EXIT_SUCCESS); + exit(1); } - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { exit(EXIT_SUCCESS); } diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt index 792de59d15..fb04204902 100644 --- a/example/19_binary_elementwise/CMakeLists.txt +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -4,4 +4,6 @@ add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) -add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file +add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) +add_example_executable(example_elementwise_tanh_1d elementwise_tanh_1d.cpp) +add_example_executable(example_elementwise_fastgelu_1d elementwise_fastgelu_1d.cpp) diff --git a/example/19_binary_elementwise/elementwise_fastgelu_1d.cpp b/example/19_binary_elementwise/elementwise_fastgelu_1d.cpp new file mode 100644 index 0000000000..66ce30cd1d --- /dev/null +++ b/example/19_binary_elementwise/elementwise_fastgelu_1d.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +using F32 = float; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +using ADataType = F16; +using CDataType = F16; + +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using DeviceElementwiseFastGeluInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, + ck::Tuple, + FastGelu, + 1, + 64, + 16, + 16, + 2, + 2, + ck::Sequence<1, 0>, + ck::Sequence<1>, + ck::Sequence<1>>; + +template +void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + auto Am = A(m); + ctype Cm = 0; + functor(Cm, Am); + C(m) = Cm; + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification; + bool time_kernel; + + if(argc == 1) + { + do_verification = true; + time_kernel = false; + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = static_cast(std::stoi(argv[2])); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + + DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize()); + DeviceMem c_m_device_buf(sizeof(CDataType) * 
c_m.mDesc.GetElementSpaceSize()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths = {M}; + std::array a_strides = {1}; + std::array c_strides = {1}; + + auto broadcastFastGelu = DeviceElementwiseFastGeluInstance{}; + auto argument = broadcastFastGelu.MakeArgumentPointer( + abc_lengths, {a_strides}, {c_strides}, input, output, FastGelu{}); + + if(!broadcastFastGelu.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastFastGelu_invoker_ptr = broadcastFastGelu.MakeInvokerPointer(); + float ave_time = + broadcastFastGelu_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, Tensor, FastGelu>( + host_c_m, a_m, M, FastGelu{}); + + pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_tanh_1d.cpp b/example/19_binary_elementwise/elementwise_tanh_1d.cpp new file mode 100644 index 0000000000..e3a2d16ae9 --- /dev/null +++ b/example/19_binary_elementwise/elementwise_tanh_1d.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +using F32 = float; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +using ADataType = F16; +using CDataType = F16; + +using Tanh = ck::tensor_operation::element_wise::TanH; + +using DeviceElementwiseTanhInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, + ck::Tuple, + Tanh, + 1, + 64, + 16, + 16, + 2, + 2, + ck::Sequence<1, 0>, + ck::Sequence<1>, + ck::Sequence<1>>; + +template +void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + auto Am = A(m); + ctype Cm = 0; + functor(Cm, Am); + C(m) = Cm; + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification; + bool time_kernel; + + if(argc == 1) + { + do_verification = true; + time_kernel = false; + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = static_cast(std::stoi(argv[2])); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + + DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize()); + DeviceMem c_m_device_buf(sizeof(CDataType) * 
c_m.mDesc.GetElementSpaceSize()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + + std::array input = {a_m_device_buf.GetDeviceBuffer()}; + std::array output = {c_m_device_buf.GetDeviceBuffer()}; + + std::array abc_lengths = {M}; + std::array a_strides = {1}; + std::array c_strides = {1}; + + auto broadcastTanh = DeviceElementwiseTanhInstance{}; + auto argument = broadcastTanh.MakeArgumentPointer( + abc_lengths, {a_strides}, {c_strides}, input, output, Tanh{}); + + if(!broadcastTanh.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto broadcastTanh_invoker_ptr = broadcastTanh.MakeInvokerPointer(); + float ave_time = + broadcastTanh_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, Tensor, Tanh>(host_c_m, a_m, M, Tanh{}); + + pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3); + } + + return pass ? 
0 : 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp index 9af98aa463..997839a999 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp @@ -43,29 +43,29 @@ using DeviceConvBwdWeightInstance = 256, // BlockSize 128, // MPerBlock 128, // NPerBlock - 32, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl + 64, // K0PerBlock + 16, // K1 + 16, // MPerXdl + 16, // NPerXdl 2, // MXdlPerWave 2, // NXdlPerWave S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcScalarPerVector 2, // ABlockTransferDstScalarPerVector_K1 true, // ABlockLdsAddExtraM S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcScalarPerVector 2, // BBlockTransferDstScalarPerVector_K1 true, // BBlockLdsAddExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format off diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp index 466aec51bf..72ad0997c5 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp +++ 
b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp @@ -41,29 +41,29 @@ using DeviceConvBwdWeightInstance = 256, // BlockSize 128, // MPerBlock 128, // NPerBlock - 32, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl + 64, // K0PerBlock + 16, // K1 + 16, // MPerXdl + 16, // NPerXdl 2, // MXdlPerWave 2, // NXdlPerWave S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcScalarPerVector 2, // ABlockTransferDstScalarPerVector_K1 false, // ABlockLdsAddExtraM S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcScalarPerVector 2, // BBlockTransferDstScalarPerVector_K1 false, // BBlockLdsAddExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl template diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index ea113dc7be..3ee0364585 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -158,8 +158,8 @@ void host_gemm_layernorm(Tensor& h_m_n, int main() { - // temp disable on gfx11 - if(ck::is_gfx11_supported()) + // temp disable on gfx11 & gfx12 + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { return 0; } diff --git 
a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index 9880209452..70df5553cd 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -74,7 +74,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index 5510c8e001..b2b236f5a0 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -69,14 +69,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD CDEElementOp, GemmDefault, 256, // BlockSize - 256, // MPerBlock + 128, // MPerBlock 128, // NPerBlock 32, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXDL 16, // NPerXDL - 8, // MXdlPerWave + 4, // MXdlPerWave 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp index 2f39131b6d..5d46b75ae6 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -66,15 +66,15 @@ using DeviceBatchedGemmV2Instance = ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, Scale_Block_N, Scale_Block_K, - 16, 64, + 32, 64, KPerBlock, 8, 32, 16, 16, - 1, 1, + 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 16, 1, 8>, 8, + 1, 1, S<1, 16, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, 
ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; // clang-format on diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp index 257692aac6..68f8439aba 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -65,7 +65,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 86a36d53e2..a10a29725d 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -523,7 +523,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 1; } @@ -536,7 +536,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) problem_size.M = 128 * (dis(gen) + 1); problem_size.N = 128 * (dis(gen) + 1); - problem_size.K = 256 * (dis(gen) + 2); + problem_size.K = 256 * (dis(gen) + 4); problem_size.batch_count = 2; diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp index 808c548042..ef39d844da 100644 --- a/example/26_contraction/common_instances.hpp +++ b/example/26_contraction/common_instances.hpp @@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| 
Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; // clang-format on template , S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; // clang-format on template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; // clang-format on template , S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; // clang-format on // Fp64 instances. 
@@ -126,7 +126,7 @@ using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; // clang-format on template , S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, 
AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; // clang-format on template , S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; // clang-format on template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if 
constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp index d246df2315..78048f3b7d 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp @@ -25,7 +25,7 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index f098eaf7e9..b9c34e2a24 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,8 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +static constexpr auto KPerBlock = sizeof(InKernelDataType) == 1 ? 
64 : 32; + // instance for double rate mfma on gfx950 (vs gfx942) template using DeviceConvFwdInstance2 = @@ -105,7 +107,7 @@ using DeviceConvFwdInstance = 256, // BlockSize 128, // MPerBlock 128, // NPerBlock - 32, // KPerBlock + KPerBlock, // KPerBlock 4, // AK1 4, // BK1 16, // MPerXdl @@ -333,11 +335,17 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, #ifdef BUILD_INT4_EXAMPLE const Tensor out_device_converted(out_device); - return ck::utils::check_err( - out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err(out_device_converted, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); #else - return ck::utils::check_err( - out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); #endif } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index da8d956b91..81ad1d4af9 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -23,14 +23,14 @@ using DeviceConvFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 16, // KPerBlock + 128, // NPerBlock + 64, // KPerBlock 4, // AK1 4, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -104,7 +104,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config, in_device_buf.ToDevice(in.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); #endif - std::array a_g_n_c_wis_lengths{}; std::array a_g_n_c_wis_strides{}; std::array b_g_k_c_xs_lengths{}; @@ -128,7 +127,6 @@ 
bool run_grouped_conv_fwd(const ExecutionConfig& config, copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - // do Conv auto conv = DeviceConvFwdInstance{}; auto invoker = conv.MakeInvoker(); @@ -151,7 +149,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config, InElementOp{}, WeiElementOp{}, OutElementOp{}); - if(!conv.IsSupportedArgument(argument)) { throw std::runtime_error( diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp index c5e1844e90..91c69c8c30 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp @@ -87,11 +87,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 4, // AK1 4, // BK1 1, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -114,7 +114,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 1, false, 1, // CShuffleMXdlPerWavePerShuffle - 2, // CShuffleNXdlPerWavePerShuffle + 4, // CShuffleNXdlPerWavePerShuffle S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock @@ -138,7 +138,7 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 7efa169a7d..d670fe8cc1 100644 
--- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -7,6 +7,10 @@ add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_f add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) +if(GPU_TARGETS MATCHES "gfx125") + target_compile_definitions(example_self_attention_forward_wmma_fp16 PRIVATE USE_GFX125_CONFIG=1) + target_compile_definitions(example_cross_attention_forward_wmma_fp16 PRIVATE USE_GFX125_CONFIG=1) +endif() add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp index ae5bf950a7..6cb0fb2106 100644 --- a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp @@ -33,10 +33,10 @@ using DeviceGemmV2Instance = ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, - 128, 256, 64, + 128, 128, 64, 8, 8, 16, 16, - 4, 4, + 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp index 94e57d3d52..e92b4d9b53 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp @@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| 
Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index a56f841689..6410946a0d 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp index 02f7fc883e..52f27e0602 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -54,14 +54,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | | | | | 
| | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_splitK_gemm_example.inc" int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp index d55b760cf0..1c2faf466e 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -52,7 +52,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 32, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp index 7ae29023cb..3987489454 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -64,7 +64,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on #else @@ -85,7 +85,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlS int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp index a80b105531..7b4628baca 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | 
| PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_conv_bwd_data_bias_relu_example.inc" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp index 3c49710416..2f4dcac8b0 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp @@ -31,7 +31,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>; + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>; // clang-format on #include "run_grouped_conv_bwd_data_example.inc" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index cc030d2e2b..1081b1e4bd 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| 
Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_conv_bwd_data_example.inc" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp index f8d61d2db6..b11cbfb879 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp @@ 
-30,7 +30,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>; + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>; // clang-format on #include "run_grouped_conv_bwd_data_example.inc" diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index e0fd5a1de0..bdd1d328b4 100644 --- 
a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -8,6 +8,5 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1") - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) -endif() +add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) + diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp index c0a9dbe519..10cc163ee9 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -77,11 +77,11 @@ using DeviceBatchedGemmGemmInstance = 4, // AK1 4, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -106,13 +106,13 @@ using DeviceBatchedGemmGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 2>; // CShuffleBlockTransferScalarPerVector_NPerBlock #include "run_grouped_conv_conv_fwd_example.inc" int main(int argc, char* argv[]) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx120_supported()) { return 0; } diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp index 
204bbc6aa4..34d62408b2 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp @@ -106,7 +106,7 @@ using DeviceBatchedGemmGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock #include "run_grouped_conv_conv_fwd_example.inc" diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index b043cd878c..5dbcd83eab 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 
64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index 31641955b7..d850469ca7 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, 
AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 2ceca3c877..cca739d498 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -71,7 +71,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| 
_NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp index 7f5e4cddc3..12229a340c 100644 --- a/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp @@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp index 4513723664..8993218b7f 100644 --- a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp @@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 
32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp index e3bbfeeb50..394a122e76 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp @@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp index d169412c5e..ac548d524c 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp @@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp index fb89db20d7..bbc5216548 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8.cpp @@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp index 656bcd0131..a430b469f3 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_fp8_bf8.cpp @@ 
-52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp index 8cadb1b720..75f2e05539 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp @@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp index f5639c13c6..5df89812c2 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp @@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp index 08c9cc08f8..946f956c42 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp @@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp index ed7fabf02e..ee1694df10 100644 --- 
a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp @@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp b/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp index 102fbd2fc3..60a4b1d57c 100644 --- a/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_wmma_convscale_add_fp8.cpp @@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp index b2ad648801..d64c31e6c0 100644 --- a/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp @@ -54,7 +54,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp index 7c50054197..e259b5b180 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_wmma_convscale_amax_fp8.cpp @@ -48,7 +48,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma 
diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp index e0255f770f..983c2e5527 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp @@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp index dcf1af80c7..e2f3e17352 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp @@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp index 14a2659573..074b96871c 100644 --- a/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_wmma_convscale_relu_fp8.cpp @@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance = 64, // BlockSize 64, // MPerBlock 64, // NPerBlock - 32, // KPerBlock + 128, // KPerBlock 8, // AK1 8, // BK1 16, // MPerWmma diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp index fe650b461b..98e23207f4 100644 --- a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp +++ 
b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp @@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance = 256, // BlockSize 128, // MPerBlock 256, // NPerBlock - 32, // KPerBlock + 64, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp index d0da300eb3..e905f0656e 100644 --- a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp @@ -26,9 +26,9 @@ using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + sizeof(DataType) == 1 ? 
64 : 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerWmma + 16, // NPerWmma + 4, // MWmmaPerWave + 4, // NWmmaPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN 1, 1, S<1, 32, 1, 8>, @@ -123,33 +124,33 @@ using DeviceGroupedConvNDMultiABFwdInstance = InElementOp, WeiElementOp, OutElementOp, - ConvSpec, // ConvForwardSpecialization - GemmSpec, // GemmSpecialization - 1, // - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 - 16, // MPerXdl - 16, // NPerXdl - 4, // MXdlPerWave - 8, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + sizeof(DataType) == 1 ? 
64 : 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN 1, 1, S<1, 32, 1, 8>, diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 24a4106ae7..a1a7dc2e8f 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -20,7 +20,8 @@ add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_bl add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp) -list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic) +list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic gfx1250) + set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp index 5e0851dbb0..841b5399f5 100644 --- 
a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp @@ -49,8 +49,6 @@ using D1Layout = Col; using DsLayout = ck::Tuple; using ELayout = Row; -static constexpr int KPack = 8; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -58,6 +56,13 @@ using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr int KPerBlock = 64; +#if defined(CK_USE_GFX1250) +static constexpr int KPack = 16; +#else +static constexpr int KPack = 8; +#endif +static constexpr auto K0 = KPerBlock / KPack; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< @@ -65,12 +70,12 @@ using DeviceOpInstance = A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, - 32, 128, 128, - 8, 8, + 32, 128, KPerBlock, + KPack, KPack, 16, 16, 2, 2, - S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp index ba95724d3f..5a6740765e 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp @@ -50,8 +50,6 @@ using D1Layout = Col; using DsLayout = ck::Tuple<>; using ELayout = Row; -static constexpr int KPack 
= 16; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -63,6 +61,13 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; +static constexpr int KPerBlock = 128; +#if defined(CK_USE_GFX1250) +static constexpr int KPack = 32; +#else +static constexpr int KPack = 16; +#endif +static constexpr auto K0 = KPerBlock / KPack; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle @@ -71,13 +76,13 @@ using DeviceOpInstance = A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 128, 128, 128, - 16, 16, + 128, 128, KPerBlock, + KPack, KPack, 16, 16, 4, 2, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp index 15e7a5fb16..77173e3297 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp @@ -34,10 +34,9 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using A0DataType = F8; -using B0DataType = F8; -static constexpr int KPack = 16; -using ComputeType = F8; +using A0DataType = F8; +using B0DataType = F8; +using ComputeType = F8; using AccDataType = F32; using CShuffleDataType = F32; @@ -60,7 +59,13 @@ using BElementOp = PassThrough; 
using CDEElementOp = MultiplyMultiply; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - +static constexpr int KPerBlock = 256; +#if defined(CK_USE_GFX1250) +static constexpr int KPack = 32; +#else +static constexpr int KPack = 16; +#endif +static constexpr auto K0 = KPerBlock / KPack; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle< @@ -68,12 +73,12 @@ using DeviceOpInstance = A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 32, 128, 256, - 16, 16, + 32, 128, KPerBlock, + KPack, KPack, 16, 16, 2, 1, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index e65eed9ffa..02109ba347 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -74,7 +74,14 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PipelineVer = []() { +#if defined(CK_USE_WMMA) && !defined(CK_USE_GFX1250) + return ck::BlockGemmPipelineVersion::v1; +#else + return ck::BlockGemmPipelineVersion::v3; +#endif +}(); using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off @@ -88,7 +95,7 @@ using 
DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp index 077a72b080..dd23192418 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -114,7 +114,14 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PipelineVer = []() { +#if defined(CK_USE_WMMA) && !defined(CK_USE_GFX1250) + return ck::BlockGemmPipelineVersion::v1; +#else + return ck::BlockGemmPipelineVersion::v3; +#endif +}(); using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off @@ -135,7 +142,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, I8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 62e86d7682..d5eb87c728 100644 --- 
a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -6,9 +6,15 @@ add_custom_target(example_gemm_mx) add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) +add_example_executable(example_gemm_mx_fp8_v1 gemm_mx_fp8_v1.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_v1) + add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) +add_example_executable(example_gemm_mx_fp8_bpreshuffle gemm_mx_fp8_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bpreshuffle) + # TODO: Fix RRR # add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) # add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) @@ -63,9 +69,81 @@ example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_M set(FP8_MXGEMM_OPTIONS) list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_fp8_v1 PRIVATE ${FP8_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_fp8_bpreshuffle PRIVATE ${FP8_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) set(FP6_MXGEMM_OPTIONS) list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS}) + +function(add_gemm_mix_prec A_NAME B_NAME A_TYPE B_TYPE) + add_example_executable(example_gemm_mx_${A_NAME}_${B_NAME} gemm_mx_fp4.cpp) + add_example_dependencies(example_gemm_mx example_gemm_mx_${A_NAME}_${B_NAME}) + + add_example_executable(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp) + add_example_dependencies(example_gemm_mx example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle) + + 
example_compile_options(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS}) + target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE}) + + example_compile_options(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE}) +endfunction(add_gemm_mix_prec) + +function(add_moe_mix_prec A_NAME B_NAME A_TYPE B_TYPE) + add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns moe_gemm1_xdl_mx_fp4_bns.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns) + + add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns moe_gemm2_xdl_mx_fp4_bns.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns) + + add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} moe_gemm1_xdl_mx_fp4.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}) + + add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} moe_gemm2_xdl_mx_fp4.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}) + + add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle) + + add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp) + add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle) + + # mx moe B no-shuffling + scale shuffling + 
example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE B_DATATYPE=${B_TYPE}) + target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE B_DATATYPE=${B_TYPE}) + + # mx moe B no-shuffling + scale shuffling (async loads) + example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE}) + target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE}) + + # mx moe B shuffling + scale shuffling (async loads) + example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE}) + target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE}) + target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE}) + 
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE}) +endfunction(add_moe_mix_prec) + +# mx mixed precision +if(GPU_TARGETS MATCHES "gfx125") + add_gemm_mix_prec(fp4 fp8 F4 F8) + add_gemm_mix_prec(fp8 fp4 F8 F4) + + add_moe_mix_prec(fp4 fp8 F4 F8) + add_moe_mix_prec(fp8 fp4 F8 F4) + add_moe_mix_prec(fp8 fp8 F8 F8) +endif() diff --git a/example/67_gemm_microscaling/gemm_mx_bf6.cpp b/example/67_gemm_microscaling/gemm_mx_bf6.cpp index 6d5b4c17e2..bca946660c 100644 --- a/example/67_gemm_microscaling/gemm_mx_bf6.cpp +++ b/example/67_gemm_microscaling/gemm_mx_bf6.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" using ADataType = ck::bf6x16_pk_t; using BDataType = ck::bf6x16_pk_t; diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp index 2f4e2a5c0b..d2794ae444 100644 --- a/example/67_gemm_microscaling/gemm_mx_bf8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" using ADataType = ck::bf8_t; using BDataType = ck::bf8_t; @@ -44,14 +45,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 128, // BlockSize: Thread block size - 128, // MPerBlock - 32, // NPerBlock + 64, // MPerBlock + 64, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL - 4, // MXdlPerWave + 2, // MXdlPerWave 2, // NXdlPerWave S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 
74a4ce1bb8..c2a490877f 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" @@ -27,9 +26,10 @@ using ::ck::Tensor; template using S = ck::Sequence; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using MFMA = ck::tensor_layout::gemm::MFMA; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -116,7 +116,7 @@ bool parse_cmd_args(int argc, } template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +void preShuffleScaleBuffer_gfx950(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) { int MNXdlPack = 2; int KXdlPack = 2; @@ -126,8 +126,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i int K0 = K / KXdlPack / XdlKThread; // KRepeat - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + // On gfx950, WarpSize=64: + // The 4 16x128 building blocks will be packed into 1 32x256 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 // unfold the MN32xK(256/32) scale buffer // 4 16 2 2 @@ -163,13 +164,74 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i } } -void preShuffleBuffer(const ck::f4x2_pk_t* src, 
ck::f4x2_pk_t* dst, int N, int K, int NXdl) +/** + * Pre-shuffle scale buffer for gfx1250 16x16x128 wmma scale instruction + * + * @tparam ScaleType Scale data type + * @tparam KStride Whether K is the leading dimension of the scale buffer + */ +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + ck::index_t MN, + ck::index_t K) { - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - int K_pk = K / 2; - int K0 = K_pk / (KLane * KPack); + + static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, + "wrong! only support 8-bit scale with ScaleBlockSize=32"); + + constexpr ck::index_t MPerXdlops = 16; + // constexpr ck::index_t NPerXdlops = 16; + constexpr ck::index_t KPerXdlops = 128; + + int MNPack = 2; // 2 sets of scales in M/N direction + int KPack = 1; // 1 set of scales in K direction + + int MNStep = MPerXdlops; + int KStep = KPerXdlops / ScaleBlockSize; // scales per thread + + int K0 = K / KPack / KStep; // KRepeat - how many KStep blocks + + // On gfx1250, WarpSize=32: + // -- The 2 16x128 building blocks will be packed into 1 32x128 + // -- The 4 16x16x128 wmma will be packed into 1 32x32x128 + + // unfold the MN32xK(128/32) scale buffer + // 4 16 1 2 + // To KStep -> MNStep -> KPack -> MNPack + // or ??? 
+ // 2 16 1 4 + // MNPack -> MNStep -> KPack -> KStep + for(int mn = 0; mn < MN; ++mn) + { + int iMNRepeat = mn / (MNStep * MNPack); // i MNRepeat (MN block id) + int tempmn = mn % (MNStep * MNPack); // position in MN block + + for(int k = 0; k < K; ++k) + { + int iKRepeat = k / (KStep * KPack); // i KRepeat + int tempk = k % (KStep * KPack); // position in KStep block + + int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + + tempmn * (KStep * KPack) + tempk; + + if constexpr(KStride) + dst[outputIndex] = src[mn * K + k]; + else + dst[outputIndex] = src[k * MN + mn]; + } + } +} + +template +void preShuffleBuffer(const T* src, T* dst, int N, int K, int NXdl) +{ + const int KPack = 16; + const int NLane = NXdl; + const int KLane = ck::get_warp_size() / NLane; + const int K_pk = K / ck::packed_size_v; + const int K0 = K_pk / (KLane * KPack); // K -> K0 KLane KPack // N -> N0 NLane // N, K -> N0 K0 KLane NLane KPack @@ -352,7 +414,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} break; - case 2: a_m_k.GenerateTensorDistr( float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2] @@ -369,12 +430,34 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c } } - preShuffleScaleBuffer>(a_m_k_scale.mData.data(), - a_shuffled_scale.mData.data(), - Scale_Padded_M, - K / ScaleBlockSize); - preShuffleScaleBuffer>( - b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if(ck::get_warp_size() == 64) + { + preShuffleScaleBuffer_gfx950>(a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + + preShuffleScaleBuffer_gfx950>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, 
K / ScaleBlockSize); + } + else if(ck::get_warp_size() == 32) + { + preShuffleScaleBuffer_gfx1250>( + a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + + preShuffleScaleBuffer_gfx1250>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + } + else + { + throw std::runtime_error("wrong! Scale pre-shuffle unsupported warp size"); + } + if constexpr(BPreShuffle) { int NPerXdl = 16; // Fixed 16 @@ -459,7 +542,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c c_device_buf.FromDevice(c_m_n_device_result.mData.data()); if(config.verbosity > 0) { - std::cout << "Done." << std::endl; + std::cout << "\nDone." << std::endl; std::cout << "Computing GEMM on host..." << std::endl; } diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp index a108a7848a..e33b68acbb 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -2,9 +2,20 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +using F4 = ck::f4x2_pk_t; +using F8 = ck::f8_t; -using ADataType = ck::f4x2_pk_t; -using BDataType = ck::f4x2_pk_t; +#if defined(A_DATATYPE) +using ADataType = A_DATATYPE; +#else +using ADataType = F4; +#endif +#if defined(B_DATATYPE) +using BDataType = B_DATATYPE; +#else +using BDataType = F4; +#endif using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; @@ -21,7 +32,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t DataPackedSize = + ck::packed_size_v; // Packed representation of data constexpr 
ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp index 5f7a5bfa9e..9c92aebd38 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -2,9 +2,21 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" -using ADataType = ck::f4x2_pk_t; -using BDataType = ck::f4x2_pk_t; +using F4 = ck::f4x2_pk_t; +using F8 = ck::f8_t; + +#if defined(A_DATATYPE) +using ADataType = A_DATATYPE; +#else +using ADataType = F4; +#endif +#if defined(B_DATATYPE) +using BDataType = B_DATATYPE; +#else +using BDataType = F4; +#endif using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; @@ -21,7 +33,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t DataPackedSize = + ck::packed_size_v; // Packed representation of data constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 @@ -30,7 +43,7 @@ constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; // AB DataType: f4x2_pk_t -// Mathmatically, all numbers are represented as f4x2. +// Mathematically, all numbers are represented as f4x2. 
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout BLayout, // BLayout @@ -47,24 +60,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 512, // NPerBlock + 128, // BlockSize: Thread block size + 64, // MPerBlock + 64, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL - 8, // MXdlPerWave - 8, // NXdlPerWave - S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 2, // MXdlPerWave + 2, // NXdlPerWave + S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM - S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim @@ -72,9 +85,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle 16, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN 2, // CShuffleMXdlPerWavePerShuffle - 4, // CShuffleNXdlPerWavePerShuffle - S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlockW + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 8, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CShuffleBlockTransferScalarPerVector_NPerBlock BlkGemmPSched, // BlkGemmPipeSched BlkGemmPVer, // BlkGemmPipelineVer ADataType, // ComputeTypeA diff --git 
a/example/67_gemm_microscaling/gemm_mx_fp6.cpp b/example/67_gemm_microscaling/gemm_mx_fp6.cpp index 615980082d..0816162fbb 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp6.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp6.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" using ADataType = ck::f6x16_pk_t; using BDataType = ck::f6x16_pk_t; diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp index 0e28770ad4..d83d5bc83a 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" using ADataType = ck::f8_t; using BDataType = ck::f8_t; diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp index 49caf80a9e..862377a7a6 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" using ADataType = ck::f8_t; using BDataType = ck::bf8_t; diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..c229c61ff8 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = MFMA; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 
16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_v1.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_v1.cpp new file mode 100644 index 0000000000..0ca00bc97e --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_v1.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 64, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, 
// ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index 586ecd81bf..eae5a204cb 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -1,47 +1,29 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; using F4 = ck::f4x2_pk_t; +using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using A0DataType = F4; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif using A1DataType = XPackedDataType; -using B0DataType = F4; using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; @@ -89,67 +71,14 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / 
XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul static constexpr ck::index_t 
MPerBlock = 128; @@ -157,6 +86,11 @@ static constexpr ck::index_t NPerBlock = 64; static constexpr ck::index_t BlockSize = 256; static constexpr bool MulRoutedWeight = true; +static constexpr ck::index_t ClusterLengths_BK0 = + ck::is_same_v && ck::is_same_v ? 8 : 4; +static constexpr ck::index_t ClusterLengths_N = + ck::is_same_v && ck::is_same_v ? 32 : 64; + // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< A0Layout, B0Layout, DsLayout, ELayout, @@ -166,10 +100,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MPerBlock, NPerBlock, KPerBlock, 16, 16, 16, 16, - 4, 2, + 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<4, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -373,6 +307,9 @@ int main(int argc, char* argv[]) DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + // a0_t_k.savetxt("a.txt", "float", 128); + // b0_e_n_k.savetxt("b.txt", "float", 128); + // A scale sorted for(int i = 0; i < sorted_size; i++) { @@ -392,14 +329,30 @@ int main(int argc, char* argv[]) } // A/B scale shuffle - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>(b1_e_n_k.mData.data(), - b_scale_preshuffled.mData.data(), - N * 2 * experts, - K / ScaleBlockSize); + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + 
N * 2 * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -452,9 +405,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index b3b2ebcbc0..0e793ee2cf 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -1,33 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - +using F8 = ck::f8_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -35,14 +12,19 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif +using A1DataType = XPackedDataType; +using B1DataType = XPackedDataType; -using A0DataType = F4; -using A1DataType = XPackedDataType; -using B0DataType = F4; -using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; @@ -89,67 +71,14 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, 
ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; // 128 fp4x2 or 128 
fp8 static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul static constexpr ck::index_t MPerBlock = 128; @@ -166,10 +95,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MPerBlock, NPerBlock, KPerBlock, 16, 16, 16, 16, - 4, 2, + 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<4, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -389,14 +318,30 @@ int main(int argc, char* argv[]) } // A/B scale shuffle - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>(b1_e_n_k.mData.data(), - b_scale_preshuffled.mData.data(), - N * 2 * experts, - K / ScaleBlockSize); + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -449,9 +394,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == 
"gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index 5c7668ab73..45b5b75820 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -1,33 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - +using F8 = ck::f8_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -36,13 +13,17 @@ using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using A0DataType 
= F4; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif using A1DataType = XPackedDataType; -using B0DataType = F4; using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; @@ -90,88 +71,6 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - -// B preshuffle -void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) -{ - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - int K_pk = K / 2; - int K0 = K_pk / (KLane * KPack); - // K -> K0 KLane KPack - // N -> N0 NLane - // N, K -> N0 K0 KLane NLane KPack - I64 tempk; - for(I64 n = 0; n < N; ++n) - { - for(I64 k = 0; k < K_pk; ++k) - { - I64 n0 = n / NLane; - I64 n1 = n % NLane; - - I64 k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); - I64 k1 = tempk / KPack; - I64 k2 = tempk % KPack; - - I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + - k1 * KPack * NLane + n1 * KPack + k2; - - dst[outputIndex] = src[n * K_pk + k]; - } - } -} - -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 
= k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -180,9 +79,8 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul static constexpr ck::index_t MPerBlock = 32; @@ -194,12 +92,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 128, KPerBlock, + MPerBlock, 256, KPerBlock, 16, 16, 16, 16, - 2, 2, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, 
ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -415,15 +313,30 @@ int main(int argc, char* argv[]) } // A/B scale shuffle - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>(b1_e_n_k.mData.data(), - b_scale_preshuffled.mData.data(), - N * 2 * experts, - K / ScaleBlockSize); - + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -482,9 +395,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) @@ -566,7 +480,6 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); - auto status = ck::utils::check_err( e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp 
b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 04c3afc62b..c2075da934 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -1,33 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - +using F8 = ck::f8_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -35,13 +12,17 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using A0DataType = F4; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif using A1DataType = XPackedDataType; -using B0DataType = F4; using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; @@ 
-86,71 +67,22 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr 
ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 - +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t MPerBlock = 128; static constexpr bool MulRoutedWeight = true; +static constexpr ck::index_t ClusterLengths_BK0 = + ck::is_same_v && ck::is_same_v ? 8 : 4; +static constexpr ck::index_t ClusterLengths_N = + ck::is_same_v && ck::is_same_v ? 32 : 64; + // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< A0Layout, B0Layout, DsLayout, ELayout, @@ -162,7 +94,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -403,13 +335,30 @@ int main(int argc, char* argv[]) } } - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>( - b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); - + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); 
expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -461,9 +410,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) @@ -539,7 +489,8 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - + e_t_n_device_result.savetxt("e_device.txt"); + e_t_n_host_result.savetxt("e_host.txt"); return ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 12bb76eccd..c00c96faca 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -1,33 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - +using F8 = ck::f8_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -35,13 +12,17 @@ using F32 = float; using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; - -using A0DataType = F4; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif using A1DataType = XPackedDataType; -using B0DataType = F4; using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; @@ -86,67 +67,13 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - 
int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t 
MPerBlock = 128; static constexpr bool MulRoutedWeight = true; @@ -162,7 +89,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -388,12 +315,30 @@ int main(int argc, char* argv[]) } } - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>( - b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -446,9 +391,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) diff --git 
a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index 6a5f5a6b9f..89d093f254 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -1,33 +1,10 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "gemm_mx_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -using ::ck::DeviceMem; -using ::ck::HostTensorDescriptor; -using ::ck::Tensor; - -template -using S = ck::Sequence; - +using F8 = ck::f8_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -36,14 +13,19 @@ using XDataType = ck::e8m0_bexp_t; using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t using I64 = int64_t; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Bypass = ck::tensor_layout::BypassLayoutVerification; +#if defined(A_DATATYPE) +using A0DataType = A_DATATYPE; +#else +using A0DataType = F4; +#endif +#if defined(B_DATATYPE) +using B0DataType = B_DATATYPE; +#else +using B0DataType = F4; +#endif +using A1DataType = XPackedDataType; +using B1DataType = 
XPackedDataType; -using A0DataType = F4; -using A1DataType = XPackedDataType; -using B0DataType = F4; -using B1DataType = XPackedDataType; using EDataType = F16; using AccDataType = F32; using CShuffleDataType = F16; @@ -87,14 +69,12 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; - // B preshuffle void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) { int KPack = 16; int NLane = NXdl; - int KLane = 64 / NLane; + int KLane = ck::get_warp_size() / NLane; int K_pk = K / 2; int K0 = K_pk / (KLane * KPack); // K -> K0 KLane KPack @@ -121,65 +101,14 @@ void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) } } -// A, B Scale preshuffle -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, 
n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t MPerBlock = 128; static constexpr bool MulRoutedWeight = true; @@ -190,12 +119,12 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ScaleBlockSize, 256, - MPerBlock, 128, KPerBlock, + MPerBlock, 256, KPerBlock, 16, 16, 16, 16, - 8, 2, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -438,12 +367,30 @@ int main(int argc, char* argv[]) } // A, B Scale preshuffle - preShuffleScaleBuffer>(a_scale_sorted.mData.data(), - a_scale_preshuffled.mData.data(), - sorted_size, - K / ScaleBlockSize); - preShuffleScaleBuffer>( - b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + if(ck::is_gfx125_supported()) + { + preShuffleScaleBuffer_gfx1250>( + a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + 
sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx1250>( + b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } + else + { + preShuffleScaleBuffer_gfx950>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer_gfx950>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * experts, + K / ScaleBlockSize); + } sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -503,9 +450,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx125_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl; } if(time_kernel) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c39f89fcaf..aedb73a70e 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -52,6 +52,9 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME) endfunction(add_example_dependencies EXAMPLE_NAME) function(add_example_executable EXAMPLE_NAME FILE_NAME) + foreach(source IN LISTS ARGN) + set(FILE_NAME ${FILE_NAME} ${source}) + endforeach() message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) @@ -110,9 +113,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - #Do not build any microscaling examples if gfx950 target is not on the list - if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") - message(DEBUG "removing microscaling example ${source} ") + #Do not build any microscaling examples if gfx950|gfx125 target is not on the list + 
if(source_name MATCHES "_mx" AND NOT EX_TARGETS MATCHES "gfx95|gfx125") + message(STATUS "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() #Do not build any FP8 examples if CK_ENABLE_FP8 not set diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index b685bfe6ab..0650bd3de0 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -3,7 +3,7 @@ set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) # Currently only gfx9 and gfx12 archs are supported by FMHA -list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]") +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") if(NOT INST_TARGETS) message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") return() @@ -234,6 +234,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") # not using add_example_executable() to add this target, since we don't want this to be included in @@ -241,6 +242,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long diff --git 
a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0a71ef6770..1191b5d6de 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -517,6 +517,25 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): return [] +class KernelComponentFactoryGfx125(KernelComponentFactoryBase): + arch = ArchTrait("gfx125") + + @staticmethod + def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if tr_load == "t": + return [] + if dtype in ["fp16", "bf16"]: + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1), + #FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 64, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaBwdDQDKDVTileSize( 32, 64, 256, 32, 256, 32, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1), + ] # fmt: skip + return [] + + def get_factory(target: str): # Place more specific architectures first @@ -524,9 +543,10 @@ def get_factory(target: str): return KernelComponentFactoryGfx950 if target.startswith("gfx9"): return KernelComponentFactoryGfx9 - if target.startswith("gfx11"): return KernelComponentFactoryGfx11 + if target.startswith("gfx125"): + return KernelComponentFactoryGfx125 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d67fc06690..7cd15e8a57 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1334,6 +1334,69 @@ class 
KernelComponentFactoryGfx12(CompatibilityRuleFactory): pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip return pipelines +class KernelComponentFactoryGfx125(CompatibilityRuleFactory): + arch = ArchTrait("gfx125") + + _DT_FP16_BF16 = ("fp16", "bf16") + _DT_FP8_FP8BF16 = ("fp8", "fp8bf16") + _DT_FP8FP32 = ("fp8fp32",) + + @classmethod + def supported_dtypes(cls) -> Tuple[str]: + return cls._DT_FP16_BF16 + cls._DT_FP8_FP8BF16 + cls._DT_FP8FP32 + + @classmethod + def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: + if dtype in cls._DT_FP16_BF16: + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 32, 32, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1)], + (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1)], + } # fmt: skip + elif dtype in cls._DT_FP8_FP8BF16: + return { + # bm0, bn0, bk0, bn1, bk1, + ( 64, 64) : [FmhaFwdTileSize(128, 64, 64, 64, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1)], + (128, 128) : [FmhaFwdTileSize( 64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1)], + #(256, 256) : [FmhaFwdTileSize( 64, 32, 64, 256, 64, 256, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1)], + } # fmt: skip + elif dtype in cls._DT_FP8FP32: + return { + # bm0, bn0, bk0, bn1, bk1, + (128, 128) : [FmhaFwdTileSize( 64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1)], + } # fmt: skip + else: + raise ValueError(f"unsupported dtype={dtype}") + + @classmethod + def get_pipelines( + cls, dtype, hdim, hdim_v, receipt, mask_impl + ) -> 
List[FmhaFwdPipeline]: + pipelines = [] + if dtype in cls._DT_FP16_BF16: + qscale = "no" + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: + # no need lse/dropout kernels + for logits, qscale, mask, bias in itertools.product( + ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ): + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + return pipelines class CustomFactory(KernelComponentFactoryGfx9, CompatibilityRuleFactoryGfx9): @classmethod @@ -1355,11 +1418,12 @@ def get_factory(target: str): return KernelComponentFactoryGfx950 if target.startswith("gfx9"): return KernelComponentFactoryGfx9 - if target.startswith("gfx115"): return KernelComponentFactoryGfx115 if target.startswith("gfx11"): return KernelComponentFactoryGfx11 + if target.startswith("gfx125"): + return KernelComponentFactoryGfx125 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 793a743df7..03d04037e7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -408,14 +408,19 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): arch = 
ArchTrait("gfx12") +class KernelComponentFactoryGfx125(KernelComponentFactoryBase): + arch = ArchTrait("gfx125") + + def get_factory(target: str): # Place more specific architectures first if target.startswith("gfx9"): return KernelComponentFactoryGfx9 - if target.startswith("gfx11"): return KernelComponentFactoryGfx11 + if target.startswith("gfx125"): + return KernelComponentFactoryGfx125 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b5ffb7739d..875a655976 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -877,14 +877,38 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): return None +class KernelComponentFactoryGfx125(KernelComponentFactoryBase): + arch = ArchTrait("gfx125") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "32" : FmhaFwdTileSize( 64, 64, 32, 32, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "64" : FmhaFwdTileSize(128, 64, 64, 64, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1), + "128": FmhaFwdTileSize( 64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1), + } # fmt: skip + else: + return None + + def get_factory(target: str): # Place more specific architectures first if target.startswith("gfx9"): return KernelComponentFactoryGfx9 - if target.startswith("gfx11"): return 
KernelComponentFactoryGfx11 + if target.startswith("gfx125"): + return KernelComponentFactoryGfx125 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7c7bddb345..935b63472f 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -639,15 +639,39 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase): else: return None +class KernelComponentFactoryGfx125(KernelComponentFactoryBase): + arch = ArchTrait("gfx125") + + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype in ["fp16", "bf16"]: + return { + # bm0, bn0, bk0, bn1, bk1, + # "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 32, -1), + # "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 32, -1), + "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 32, -1), + # "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 32, -1), + } # fmt: skip + elif dtype in ["fp8", "bf8"]: + return { + # bm0, bn0, bk0, bn1, bk1, + "64": FmhaFwdTileSize(128, 64, 64, 64, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1), + "128": FmhaFwdTileSize( 64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1), + "256": FmhaFwdTileSize( 64, 64, 64, 256, 64, 256, 4, 1, 1, 4, 1, 1, 16, 16, 64, 16, 16, 64, -1), + } # fmt: skip + else: + return None def get_factory(target: str): # Place more specific architectures first if target.startswith("gfx9"): return KernelComponentFactoryGfx9 - if target.startswith("gfx11"): return KernelComponentFactoryGfx11 + if target.startswith("gfx125"): + return 
KernelComponentFactoryGfx125 if target.startswith("gfx12"): return KernelComponentFactoryGfx12 diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 40547d0719..85094df677 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,12 +1,18 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx125") add_executable(tile_example_gemm_basic gemm_basic.cpp) add_executable(tile_example_gemm_universal universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp) add_executable(tile_example_gemm_reduce gemm_splitk_two_stage_reduce.cpp) add_executable(tile_example_gemm_splitk_two_stage gemm_splitk_two_stage.cpp) + if(GPU_TARGETS MATCHES "gfx125") + add_executable(tile_example_gemm_mixed_prec gemm_mixed_prec.cpp) + add_executable(tile_example_gemm_tdm_data_cache_prefetch gemm_tdm_data_cache_prefetch.cpp) + add_executable(tile_example_gemm_weight_preshuffle_tdm_data_cache_prefetch gemm_weight_preshuffle_tdm_data_cache_prefetch.cpp) + add_executable(tile_example_gemm_universal_cluster_launch universal_gemm.cpp) + endif() set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -15,11 +21,18 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) - list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) + #list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") 
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + if(GPU_TARGETS MATCHES "gfx125") + target_compile_options(tile_example_gemm_mixed_prec PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_tdm_data_cache_prefetch PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_weight_preshuffle_tdm_data_cache_prefetch PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_universal_cluster_launch PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(tile_example_gemm_universal_cluster_launch PRIVATE CLUSTER_LAUNCH_ENABLED=1) + endif() endif() diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 4681c19f9b..d056507875 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -49,6 +49,8 @@ make tile_example_gemm_basic -j`nproc` make tile_example_gemm_universal -j`nproc` # The weight preshuffle pipeline on the gemm calculation make tile_example_gemm_weight_preshuffle -j`nproc` +# gfx125 only: weight preshuffle TDM pipeline with data cache prefetch controls +make tile_example_gemm_weight_preshuffle_tdm_data_cache_prefetch -j`nproc` ``` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp index cccc1dcc06..9dbb53ab93 100644 --- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp +++ 
b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -52,7 +52,9 @@ struct BasicInvoker constexpr ck_tile::index_t M_Warp_Tile = 16; constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + ck_tile::ignore = is_tf32_compute; #else // gfx950: fp32 uses 16x16x16 tile (native MFMA) // tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation) @@ -79,15 +81,20 @@ struct BasicInvoker BLayout, CLayout>; - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem; + using AComputeDataType = std:: + conditional_t, BDataType_, ADataType_>; + using BComputeDataType = + std::conditional_t || + std::is_same_v, + ADataType_, + BDataType_>; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; diff --git a/example/ck_tile/03_gemm/gemm_mixed_prec.cpp b/example/ck_tile/03_gemm/gemm_mixed_prec.cpp new file mode 100644 index 0000000000..bdb77a3397 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_mixed_prec.cpp @@ -0,0 +1,66 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" +#include "run_gemm_example_common.hpp" +#include "universal_gemm_invoker.hpp" + +template