mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#6978 (commit e58096d)
[CK] add composable kernel support on gfx1250 (#6978) ## Motivation Add composable kernel support on gfx1250. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Qun Lin <qlin@amd.com> Co-authored-by: jialuo12_amdeng <jia.luo@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: hsivasun_amdeng <haresh.sivasuntharampillai@amd.com>
This commit is contained in:
@@ -50,6 +50,7 @@ option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
|
||||
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
option(CK_EXPERIMENTAL_GEMM_BENCHMARK "Enable experimental gemm benchmark for gfx1250" OFF)
|
||||
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(BUILD_CK_TILE_ENGINE "Build the tile_engine subdirectory" OFF)
|
||||
@@ -303,6 +304,9 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
|
||||
add_definitions(-DCK_GFX1030_SUPPORT)
|
||||
endif()
|
||||
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
@@ -321,6 +325,8 @@ endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_OCP_FP8)
|
||||
set(CK_USE_OCP_FP8 "ON")
|
||||
add_definitions(-DCK_TILE_USE_OCP_FP8)
|
||||
set(CK_TILE_USE_OCP_FP8 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
@@ -332,6 +338,16 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_GFX950_SUPPORT)
|
||||
set(CK_GFX950_SUPPORT "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx1250")
|
||||
add_definitions(-DCK_USE_GFX1250)
|
||||
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
|
||||
set(CK_USE_NATIVE_MX_SUPPORT "ON")
|
||||
add_definitions(-DCK_GFX1250_SUPPORT)
|
||||
set(CK_GFX1250_SUPPORT "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
add_definitions(-DCK_GFX12_SUPPORT)
|
||||
endif()
|
||||
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
|
||||
add_definitions(-DCK_ENABLE_TF32)
|
||||
@@ -413,6 +429,7 @@ endif()
|
||||
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
|
||||
option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF)
|
||||
option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF)
|
||||
option(CK_TEST_DISABLE_GPU_VALIDATION "Whether to disable GPU validation in CK tests." OFF )
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
@@ -431,6 +448,10 @@ if (ENABLE_JSON_DUMP)
|
||||
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")
|
||||
endif()
|
||||
|
||||
if (CK_TEST_DISABLE_GPU_VALIDATION)
|
||||
add_compile_definitions(CK_TEST_DISABLE_GPU_VALIDATION)
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
@@ -797,6 +818,10 @@ if(BUILD_CK_PROFILER)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (CK_EXPERIMENTAL_GEMM_BENCHMARK)
|
||||
add_subdirectory(experimental/gemm_benchmark)
|
||||
endif()
|
||||
|
||||
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
|
||||
146
Jenkinsfile
vendored
146
Jenkinsfile
vendored
@@ -772,6 +772,26 @@ def cmake_build(Map conf=[:]){
|
||||
try {
|
||||
//build CK
|
||||
sh cmd
|
||||
if (runAllUnitTests){
|
||||
// Archive artifacts if they were generated
|
||||
if (fileExists("ck_build_trace_${arch_name}.json")) {
|
||||
archiveArtifacts "ck_build_trace_${arch_name}.json"
|
||||
}
|
||||
if (fileExists("clang_build_analysis_${arch_name}.log")) {
|
||||
archiveArtifacts "clang_build_analysis_${arch_name}.log"
|
||||
}
|
||||
// Process ninja build trace after full build
|
||||
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
|
||||
archiveArtifacts "ck_build_trace_${arch_name}.json"
|
||||
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
|
||||
|
||||
if (params.NINJA_FTIME_TRACE) {
|
||||
echo "running ClangBuildAnalyzer"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
|
||||
archiveArtifacts "clang_build_analysis_${arch_name}.log"
|
||||
}
|
||||
}
|
||||
} catch (Exception buildError) {
|
||||
echo "Build failed: ${buildError.getMessage()}"
|
||||
throw buildError
|
||||
@@ -796,49 +816,32 @@ def cmake_build(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
//run tests except when NO_CK_BUILD is set
|
||||
//run tests except when NO_CK_BUILD is set and except on gfx1250
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
|
||||
// do not run unit tests when building instances only
|
||||
if(!params.BUILD_INSTANCES_ONLY){
|
||||
if (!runAllUnitTests){
|
||||
// Smart Build: Run smart_build_and_test.sh
|
||||
sh """
|
||||
export WORKSPACE_ROOT=${env.WORKSPACE}
|
||||
export PARALLEL=32
|
||||
export NINJA_JOBS=${nt}
|
||||
export ARCH_NAME=${arch_name}
|
||||
export PROCESS_NINJA_TRACE=true
|
||||
export NINJA_FTIME_TRACE=${params.NINJA_FTIME_TRACE ? 'true' : 'false'}
|
||||
bash ../script/dependency-parser/smart_build_and_test.sh
|
||||
"""
|
||||
|
||||
// Archive artifacts if they were generated
|
||||
if (fileExists("ck_build_trace_${arch_name}.json")) {
|
||||
archiveArtifacts "ck_build_trace_${arch_name}.json"
|
||||
}
|
||||
if (fileExists("clang_build_analysis_${arch_name}.log")) {
|
||||
archiveArtifacts "clang_build_analysis_${arch_name}.log"
|
||||
}
|
||||
}
|
||||
else{
|
||||
// run unit tests unless building library for all targets
|
||||
// Note: This else block is when NINJA_BUILD_TRACE=false and BUILD_INSTANCES_ONLY=false
|
||||
// So no ninja trace processing needed here
|
||||
if (!params.BUILD_INSTANCES_ONLY){
|
||||
if (!runAllUnitTests && !setup_args.contains("gfx1250") ){
|
||||
// Smart Build: Run smart_build_and_test.sh
|
||||
sh """
|
||||
export WORKSPACE_ROOT=${env.WORKSPACE}
|
||||
export PARALLEL=32
|
||||
export NINJA_JOBS=${nt}
|
||||
export ARCH_NAME=${arch_name}
|
||||
export PROCESS_NINJA_TRACE=false
|
||||
export NINJA_FTIME_TRACE=false
|
||||
bash ../script/dependency-parser/smart_build_and_test.sh
|
||||
"""
|
||||
}
|
||||
else{ //run all tests
|
||||
if(!setup_args.contains("gfx1250")){
|
||||
echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)"
|
||||
sh "ninja -j${nt} check"
|
||||
|
||||
// Process ninja build trace after full build
|
||||
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
|
||||
archiveArtifacts "ck_build_trace_${arch_name}.json"
|
||||
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
|
||||
|
||||
if (params.NINJA_FTIME_TRACE) {
|
||||
echo "running ClangBuildAnalyzer"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
|
||||
archiveArtifacts "clang_build_analysis_${arch_name}.log"
|
||||
}
|
||||
}
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
sh 'ninja check-builder'
|
||||
else{ //do not run tests on gfx1250, just build everything
|
||||
echo "Building for gfx1250"
|
||||
sh "ninja -j${nt}"
|
||||
}
|
||||
if (params.RUN_ROCM_CK_TESTS) {
|
||||
sh 'ninja check-rocm-ck'
|
||||
@@ -850,47 +853,8 @@ def cmake_build(Map conf=[:]){
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
// build deb packages
|
||||
echo "Build library package"
|
||||
sh 'ninja -j64 package'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
|
||||
stash includes: "composablekernel-dev**.deb", name: "lib_package"
|
||||
}
|
||||
}
|
||||
else{
|
||||
// run unit tests unless building library for all targets
|
||||
// Note: This else block is when NINJA_BUILD_TRACE=false and BUILD_INSTANCES_ONLY=false
|
||||
// So no ninja trace processing needed here
|
||||
if (!params.BUILD_INSTANCES_ONLY){
|
||||
if (!runAllUnitTests){
|
||||
// Smart Build: Run smart_build_and_test.sh
|
||||
sh """
|
||||
export WORKSPACE_ROOT=${env.WORKSPACE}
|
||||
export PARALLEL=32
|
||||
export NINJA_JOBS=${nt}
|
||||
export ARCH_NAME=${arch_name}
|
||||
export PROCESS_NINJA_TRACE=false
|
||||
export NINJA_FTIME_TRACE=false
|
||||
bash ../script/dependency-parser/smart_build_and_test.sh
|
||||
"""
|
||||
}
|
||||
else{
|
||||
echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)"
|
||||
sh "ninja -j${nt} check"
|
||||
}
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
sh 'ninja check-builder'
|
||||
}
|
||||
if (params.RUN_ROCM_CK_TESTS) {
|
||||
sh 'ninja check-rocm-ck'
|
||||
}
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
|
||||
sh 'ninja check-builder'
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1414,6 +1378,10 @@ pipeline {
|
||||
name: "BUILD_GFX12",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx12 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX1250",
|
||||
defaultValue: true,
|
||||
description: "Build CK for gfx1250 (default: ON)")
|
||||
booleanParam(
|
||||
name: "NINJA_BUILD_TRACE",
|
||||
defaultValue: true,
|
||||
@@ -2052,6 +2020,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
/*
|
||||
stage("Build CK and run Tests on gfx1010")
|
||||
{
|
||||
when {
|
||||
@@ -2068,6 +2037,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
*/
|
||||
stage("Build CK and run Tests on gfx1030")
|
||||
{
|
||||
when {
|
||||
@@ -2116,6 +2086,21 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK for gfx1250")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_GFX1250.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1250" -DDISABLE_DL_KERNELS="ON" """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:npi-mi450-latest", config_targets: "install", no_reboot:true, build_type: 'Release', prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
post {
|
||||
always {
|
||||
@@ -2194,3 +2179,4 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllv
|
||||
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
@@ -82,7 +82,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
@@ -95,7 +95,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
@@ -141,6 +141,9 @@ add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.c
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3)
|
||||
add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp8_v3_reg_spill gemm_wmma_fp8_v3_reg_spill.cpp)
|
||||
example_compile_options(example_gemm_wmma_fp8_v3_reg_spill PRIVATE "SHELL: -Rpass-analysis=kernel-resource-usage ")
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3_reg_spill)
|
||||
add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp)
|
||||
@@ -149,6 +152,11 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx125")
|
||||
add_example_executable(example_gemm_xdl_bf16_v3_prefetch gemm_xdl_bf16_v3_prefetch.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3_prefetch)
|
||||
endif()
|
||||
add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle)
|
||||
add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp)
|
||||
|
||||
@@ -29,7 +29,7 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
128, 64, 64,
|
||||
128, 64, 128,
|
||||
16, 16, // AK1, BK1
|
||||
16, 16,
|
||||
4, 2,
|
||||
@@ -42,7 +42,6 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
|
||||
ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceComputeType = ck::f8_t;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
@@ -50,8 +49,8 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
|
||||
113
example/01_gemm/gemm_wmma_fp8_v3_reg_spill.cpp
Normal file
113
example/01_gemm/gemm_wmma_fp8_v3_reg_spill.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* \brief Example of GEMM using WMMA that illustrates register spilling on gfx1200 architecture but
|
||||
* no register spilling on gfx1250.
|
||||
*
|
||||
* This example demonstrates how more registers available on the gfx1250 architecture can help avoid
|
||||
* register spilling that occurs on gfx1200 when using a specific GEMM configuration.
|
||||
*
|
||||
* This example must be compiled with the following flag to see the resource allocations:
|
||||
* "-Rpass-analysis=kernel-resource-usage"
|
||||
*
|
||||
* On gfx1200, the kernel will show register spilling due to limited VGPRs:
|
||||
* \verbatim
|
||||
* TotalSGPRs: 105
|
||||
* VGPRs: 256
|
||||
* ScratchSize [bytes/lane]: 56
|
||||
* Dynamic Stack: False
|
||||
* Occupancy [waves/SIMD]: 5
|
||||
* SGPRs Spill: 0
|
||||
* VGPRs Spill: 15
|
||||
* LDS Size [bytes/block]: 32768
|
||||
*
|
||||
* gfx1201 - AMD Radeon RX 9070 XT
|
||||
* Problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:3840, NP:4096, KRead:4096,
|
||||
* KP:4096, AK0:512, BK0:512, MBlock: 30, NBlock: 32}
|
||||
*
|
||||
* Perf: 0.882764 ms, 145.961 TFlops, 72.4578 GB/s, DeviceGemm_Wmma_CShuffleV3<Default, RCR>
|
||||
* BlkSize: 128, BlkTile: 128x128x128, WaveTile: 16x16, WaveMap: 4x4, VmemReadVec: 8x8,
|
||||
* BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, BlkGemmPipelinePrefetchStages:
|
||||
* 1, KPack: 16
|
||||
* \endverbatim
|
||||
*
|
||||
* On gfx1250, the same kernel will not show register spilling due to increased VGPRs:
|
||||
* \verbatim
|
||||
* TotalSGPRs: 32
|
||||
* VGPRs: 318
|
||||
* ScratchSize [bytes/lane]: 0
|
||||
* Dynamic Stack: False
|
||||
* Occupancy [waves/SIMD]: 3
|
||||
* SGPRs Spill: 0
|
||||
* VGPRs Spill: 0
|
||||
* LDS Size [bytes/block]: 32768
|
||||
* \endverbatim
|
||||
*
|
||||
* \note The register allocations above can be influenced by compiler version and code
|
||||
* changes/optimizations.
|
||||
*/
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using ComputeTypeA = ck::f8_t;
|
||||
using ComputeTypeB = ck::f8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128, //blocksize
|
||||
128, 128, 128, // M/N/KPerBlock
|
||||
8, 8, // AK1, BK1
|
||||
16, 16, //MPerWmma, NPerWmma
|
||||
4, 4, //MRepeat, NRepeat
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,//
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 32, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
|
||||
ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return !run_gemm_splitk_example(argc, argv);
|
||||
}
|
||||
325
example/01_gemm/gemm_xdl_bf16_v3_prefetch.cpp
Normal file
325
example/01_gemm/gemm_xdl_bf16_v3_prefetch.cpp
Normal file
@@ -0,0 +1,325 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
#if 1
|
||||
static const uint32_t AB_K1 = 8;
|
||||
|
||||
// clang-format off
|
||||
template <bool UseDataCachePrefetch>
|
||||
using DeviceGemmV3Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
256, 256,
|
||||
64, AB_K1, AB_K1,
|
||||
16, 16,
|
||||
8, 16,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, AB_K1, AB_K1, 0,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, AB_K1, AB_K1, 0,
|
||||
2, 4, S<1, 8, 1, 16>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3,
|
||||
CDataType,
|
||||
CDataType,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
UseDataCachePrefetch>;
|
||||
// clang-format on
|
||||
#else
|
||||
// prefetch is faster on these params
|
||||
// clang-format off
|
||||
template <bool UseDataCachePrefetch>
|
||||
using DeviceGemmV3Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
128, 128,
|
||||
64, 8, 8,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3,
|
||||
CDataType,
|
||||
CDataType,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
UseDataCachePrefetch>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
template <typename GemmInstanceType, typename ProblemType>
|
||||
std::pair<bool, float> run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = GemmInstanceType{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return std::make_pair(true, ave_time);
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if((config.do_verification == 1) || (config.do_verification == 3))
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return std::make_pair(pass, ave_time);
|
||||
}
|
||||
|
||||
bool parse_cmd_args(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeSplitK& problem_size,
|
||||
ExecutionConfig& config,
|
||||
bool& compareWithNonDataCachePrefetchImpl)
|
||||
{
|
||||
compareWithNonDataCachePrefetchImpl = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4 || argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
if(argc >= 10)
|
||||
{
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[10]);
|
||||
if(argc > 12)
|
||||
{
|
||||
compareWithNonDataCachePrefetchImpl = std::stoi(argv[11]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl
|
||||
<< "arg10: KBatch" << std::endl
|
||||
<< "arg11: compareWithNonDataCachePrefetchImpl (0=no, 1=yes)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
bool compareWithNonDataCachePrefetchImpl;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, problem_size, config, compareWithNonDataCachePrefetchImpl))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto [pass, ave_time] = run_gemm<DeviceGemmV3Instance<true>>(problem_size, config);
|
||||
|
||||
if(compareWithNonDataCachePrefetchImpl)
|
||||
{
|
||||
auto [pass2, ave_time2] = run_gemm<DeviceGemmV3Instance<false>>(problem_size, config);
|
||||
std::cout << "DataCache Prefetching enabled ave_time: " << ave_time << " ms" << std::endl;
|
||||
std::cout << "DataCache Prefetching disabled ave_time: " << ave_time2 << " ms" << std::endl;
|
||||
float speedup = ave_time2 / ave_time;
|
||||
std::cout << "On average kernel with DataCache prefetching is " << speedup
|
||||
<< " times faster than without DataCache prefetching." << std::endl;
|
||||
|
||||
if(speedup < 1.0f)
|
||||
std::cout << "WARNING: Kernel with DataCache prefetching is slower!" << std::endl;
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -28,10 +28,10 @@ using DeviceGemmV2_Streamk_Instance =
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
64,
|
||||
16, 16,
|
||||
32, 32,
|
||||
256, 8, 16,
|
||||
16, 16,
|
||||
1, 1,
|
||||
2, 2,
|
||||
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
|
||||
49
example/01_gemm/gemm_xdl_fp32_v3.cpp
Normal file
49
example/01_gemm/gemm_xdl_fp32_v3.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#define CK_ENABLE_DYNAMIC_WARP_SIZE 1
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = float;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
64, 64,
|
||||
64, 4, 4,
|
||||
16, 16,
|
||||
2, 4,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 2, 2, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 2, 2, 0,
|
||||
1, 2, S<1, 32, 1, 4>, 2,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
49
example/01_gemm/gemm_xdl_fp64_v3.cpp
Normal file
49
example/01_gemm/gemm_xdl_fp64_v3.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#define CK_ENABLE_DYNAMIC_WARP_SIZE 1
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = double;
|
||||
using BDataType = double;
|
||||
using AccDataType = double;
|
||||
using CShuffleDataType = double;
|
||||
using CDataType = double;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
64, 64,
|
||||
64, 4, 4,
|
||||
16, 16,
|
||||
2, 4,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 2, 2, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 2, 2, 0,
|
||||
1, 2, S<1, 32, 1, 4>, 2,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -29,8 +29,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds
|
||||
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| buffer| SrcDstVectorDim| DstScalar|
|
||||
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| size | | PerVector|
|
||||
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if 0
|
||||
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>;
|
||||
#if 0
|
||||
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>;
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
@@ -181,7 +181,7 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
|
||||
@@ -181,7 +181,7 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
|
||||
@@ -19,7 +19,7 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
|
||||
@@ -49,9 +49,9 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // KPerBlock
|
||||
32, // AK1
|
||||
32, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
|
||||
@@ -50,9 +50,9 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // KPerBlock
|
||||
32, // AK1
|
||||
32, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
|
||||
@@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
|
||||
@@ -48,14 +48,14 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // NPerBlock
|
||||
128, // KPerBlock
|
||||
32, // AK1
|
||||
32, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
|
||||
@@ -51,10 +51,10 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
16, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -72,13 +72,13 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 16>,
|
||||
4>;
|
||||
2>;
|
||||
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -51,9 +51,9 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // KPerBlock
|
||||
32, // AK1
|
||||
32, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
|
||||
@@ -50,9 +50,9 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
128, // KPerBlock
|
||||
32, // AK1
|
||||
32, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
|
||||
@@ -17,7 +17,7 @@ using RsDataType = ck::Tuple<R0DataType>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ static constexpr auto ConvSpec =
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto KPerBlock = sizeof(ADataType) == 1 ? 64 : 32;
|
||||
// clang-format off
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceInstance =
|
||||
@@ -34,9 +35,9 @@ using DeviceInstance =
|
||||
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| Specialization| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
|
||||
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
|
||||
#else
|
||||
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>;
|
||||
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>;
|
||||
#endif
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
@@ -292,15 +293,15 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
|
||||
conv_output_device_buf.FromDevice(conv_output_device.mData.data());
|
||||
r0_device_buf.FromDevice(r0_device.mData.data());
|
||||
|
||||
auto rtol = ck::is_same_v<EDataType, BF16> ? 1e-1f : 1e-3f;
|
||||
auto pass = ck::utils::check_err(conv_output_device,
|
||||
conv_output_host,
|
||||
"Error: incorrect results! (Matrix E)",
|
||||
1e-3f,
|
||||
rtol,
|
||||
1e-3f);
|
||||
pass =
|
||||
pass && ck::utils::check_err(
|
||||
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f);
|
||||
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", rtol, 1e-3f);
|
||||
if(pass)
|
||||
std::cout << "Verification on CPU: PASS" << std::endl;
|
||||
|
||||
|
||||
@@ -328,7 +328,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
|
||||
|
||||
if(argc == 5)
|
||||
if(argc == 1) {}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -302,7 +302,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 5)
|
||||
if(argc == 1) {}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -287,7 +287,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
}
|
||||
@@ -302,7 +301,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 5)
|
||||
if(argc == 1) {}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -59,7 +59,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
@@ -71,7 +71,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1) {}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
|
||||
@@ -69,10 +69,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
|
||||
16, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
|
||||
@@ -89,7 +89,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
|
||||
S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
|
||||
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
|
||||
1>; // RThread DstScalarPerVector _MPerBlock
|
||||
// clang-format on
|
||||
@@ -121,25 +121,21 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 4 || argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
if(argc == 10)
|
||||
{
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -147,10 +143,10 @@ int main(int argc, char* argv[])
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: Measure kernel execution time (1=ON, 0=Off)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
|
||||
exit(0);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -76,10 +76,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
|
||||
16, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
8, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
|
||||
@@ -96,7 +96,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
|
||||
S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
|
||||
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
|
||||
1>; // RThread DstScalarPerVector _MPerBlock
|
||||
// clang-format on
|
||||
@@ -127,25 +127,21 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 4 || argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
if(argc == 10)
|
||||
{
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -154,10 +150,10 @@ int main(int argc, char* argv[])
|
||||
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
|
||||
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
|
||||
<< std::endl;
|
||||
exit(EXIT_SUCCESS);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
@@ -4,4 +4,6 @@
|
||||
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
|
||||
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
|
||||
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
|
||||
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
|
||||
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
|
||||
add_example_executable(example_elementwise_tanh_1d elementwise_tanh_1d.cpp)
|
||||
add_example_executable(example_elementwise_fastgelu_1d elementwise_fastgelu_1d.cpp)
|
||||
|
||||
130
example/19_binary_elementwise/elementwise_fastgelu_1d.cpp
Normal file
130
example/19_binary_elementwise/elementwise_fastgelu_1d.cpp
Normal file
@@ -0,0 +1,130 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using ADataType = F16;
|
||||
using CDataType = F16;
|
||||
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using DeviceElementwiseFastGeluInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
|
||||
ck::Tuple<CDataType>,
|
||||
FastGelu,
|
||||
1,
|
||||
64,
|
||||
16,
|
||||
16,
|
||||
2,
|
||||
2,
|
||||
ck::Sequence<1, 0>,
|
||||
ck::Sequence<1>,
|
||||
ck::Sequence<1>>;
|
||||
|
||||
template <typename HostTensorA, typename HostTensorC, typename Functor>
|
||||
void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto Am = A(m);
|
||||
ctype Cm = 0;
|
||||
functor(Cm, Am);
|
||||
C(m) = Cm;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification;
|
||||
bool time_kernel;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = true;
|
||||
time_kernel = false;
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[2]));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t M = 1024;
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor({len}, {stride});
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
a_m.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
|
||||
|
||||
DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_m_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 1> abc_lengths = {M};
|
||||
std::array<ck::index_t, 1> a_strides = {1};
|
||||
std::array<ck::index_t, 1> c_strides = {1};
|
||||
|
||||
auto broadcastFastGelu = DeviceElementwiseFastGeluInstance{};
|
||||
auto argument = broadcastFastGelu.MakeArgumentPointer(
|
||||
abc_lengths, {a_strides}, {c_strides}, input, output, FastGelu{});
|
||||
|
||||
if(!broadcastFastGelu.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastFastGelu_invoker_ptr = broadcastFastGelu.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastFastGelu_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_device_buf.FromDevice(c_m.mData.data());
|
||||
Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
host_elementwise1D<Tensor<ADataType>, Tensor<CDataType>, FastGelu>(
|
||||
host_c_m, a_m, M, FastGelu{});
|
||||
|
||||
pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
129
example/19_binary_elementwise/elementwise_tanh_1d.cpp
Normal file
129
example/19_binary_elementwise/elementwise_tanh_1d.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using ADataType = F16;
|
||||
using CDataType = F16;
|
||||
|
||||
using Tanh = ck::tensor_operation::element_wise::TanH;
|
||||
|
||||
using DeviceElementwiseTanhInstance =
|
||||
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
|
||||
ck::Tuple<CDataType>,
|
||||
Tanh,
|
||||
1,
|
||||
64,
|
||||
16,
|
||||
16,
|
||||
2,
|
||||
2,
|
||||
ck::Sequence<1, 0>,
|
||||
ck::Sequence<1>,
|
||||
ck::Sequence<1>>;
|
||||
|
||||
template <typename HostTensorA, typename HostTensorC, typename Functor>
|
||||
void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto Am = A(m);
|
||||
ctype Cm = 0;
|
||||
functor(Cm, Am);
|
||||
C(m) = Cm;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification;
|
||||
bool time_kernel;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = true;
|
||||
time_kernel = false;
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[2]));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t M = 1024;
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor({len}, {stride});
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
a_m.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
|
||||
|
||||
DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
|
||||
std::array<const void*, 1> input = {a_m_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck::index_t, 1> abc_lengths = {M};
|
||||
std::array<ck::index_t, 1> a_strides = {1};
|
||||
std::array<ck::index_t, 1> c_strides = {1};
|
||||
|
||||
auto broadcastTanh = DeviceElementwiseTanhInstance{};
|
||||
auto argument = broadcastTanh.MakeArgumentPointer(
|
||||
abc_lengths, {a_strides}, {c_strides}, input, output, Tanh{});
|
||||
|
||||
if(!broadcastTanh.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastTanh_invoker_ptr = broadcastTanh.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastTanh_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_device_buf.FromDevice(c_m.mData.data());
|
||||
Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
host_elementwise1D<Tensor<ADataType>, Tensor<CDataType>, Tanh>(host_c_m, a_m, M, Tanh{});
|
||||
|
||||
pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -43,29 +43,29 @@ using DeviceConvBwdWeightInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
64, // K0PerBlock
|
||||
16, // K1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format off
|
||||
|
||||
|
||||
@@ -41,29 +41,29 @@ using DeviceConvBwdWeightInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
64, // K0PerBlock
|
||||
16, // K1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
false, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
|
||||
@@ -158,8 +158,8 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
|
||||
|
||||
int main()
|
||||
{
|
||||
// temp disable on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
// temp disable on gfx11 & gfx12
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -69,14 +69,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
CDEElementOp,
|
||||
GemmDefault,
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
|
||||
@@ -66,15 +66,15 @@ using DeviceBatchedGemmV2Instance =
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256, Scale_Block_N, Scale_Block_K,
|
||||
16, 64,
|
||||
32, 64,
|
||||
KPerBlock, 8, 32,
|
||||
16, 16,
|
||||
1, 1,
|
||||
1, 2,
|
||||
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 8>, 8,
|
||||
1, 1, S<1, 16, 1, 8>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -523,7 +523,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
@@ -536,7 +536,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
|
||||
|
||||
problem_size.M = 128 * (dis(gen) + 1);
|
||||
problem_size.N = 128 * (dis(gen) + 1);
|
||||
problem_size.K = 256 * (dis(gen) + 2);
|
||||
problem_size.K = 256 * (dis(gen) + 4);
|
||||
|
||||
problem_size.batch_count = 2;
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
@@ -59,7 +59,7 @@ using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
@@ -81,7 +81,7 @@ using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
@@ -103,7 +103,7 @@ using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
// Fp64 instances.
|
||||
@@ -126,7 +126,7 @@ using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
@@ -148,7 +148,7 @@ using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
@@ -170,7 +170,7 @@ using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
|
||||
@@ -370,3 +370,84 @@ inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvPa
|
||||
|
||||
throw std::runtime_error("unsuppored # dim spatial");
|
||||
}
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ using BiasLayout = typename LayoutSettingSelector<NDimSpatial>::BiasLayout;
|
||||
template <ck::index_t NDimSpatial>
|
||||
using ResidualLayout = typename LayoutSettingSelector<NDimSpatial>::ResidualLayout;
|
||||
|
||||
static constexpr auto KPerBlock = sizeof(InKernelDataType) == 1 ? 64 : 32;
|
||||
|
||||
// instance for double rate mfma on gfx950 (vs gfx942)
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvFwdInstance2 =
|
||||
@@ -105,7 +107,7 @@ using DeviceConvFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
16, // MPerXdl
|
||||
@@ -333,11 +335,17 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
const Tensor<OutUserDataType> out_device_converted(out_device);
|
||||
|
||||
return ck::utils::check_err(
|
||||
out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(out_device_converted,
|
||||
out_host,
|
||||
"Error: incorrect results!",
|
||||
get_rtol<OutUserDataType>(),
|
||||
get_atol<OutUserDataType>());
|
||||
#else
|
||||
return ck::utils::check_err(
|
||||
out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(out_device,
|
||||
out_host,
|
||||
"Error: incorrect results!",
|
||||
get_rtol<OutUserDataType>(),
|
||||
get_atol<OutUserDataType>());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -23,14 +23,14 @@ using DeviceConvFwdInstance =
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
16, // KPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
@@ -104,7 +104,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
|
||||
in_device_buf.ToDevice(in.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
#endif
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
@@ -128,7 +127,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
// do Conv
|
||||
auto conv = DeviceConvFwdInstance<NDimSpatial>{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
@@ -151,7 +149,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -87,11 +87,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
1, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
8, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
@@ -114,7 +114,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
|
||||
1,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
4, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
@@ -138,7 +138,7 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,10 @@ add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_f
|
||||
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx125")
|
||||
target_compile_definitions(example_self_attention_forward_wmma_fp16 PRIVATE USE_GFX125_CONFIG=1)
|
||||
target_compile_definitions(example_cross_attention_forward_wmma_fp16 PRIVATE USE_GFX125_CONFIG=1)
|
||||
endif()
|
||||
|
||||
add_custom_target(example_gemm_scale_softmax_gemm)
|
||||
|
||||
|
||||
@@ -33,10 +33,10 @@ using DeviceGemmV2Instance =
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 256, 64,
|
||||
128, 128, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 4,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
@@ -54,14 +54,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 2>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 32, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
@@ -64,7 +64,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#else
|
||||
@@ -85,7 +85,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlS
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<BiasDataType>, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<BiasDataType>, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_bias_relu_example.inc"
|
||||
|
||||
@@ -31,7 +31,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>;
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
@@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
@@ -30,7 +30,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>;
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
@@ -8,6 +8,5 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
endif()
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
|
||||
|
||||
@@ -77,11 +77,11 @@ using DeviceBatchedGemmGemmInstance =
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
8, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
@@ -106,13 +106,13 @@ using DeviceBatchedGemmGemmInstance =
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
2>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ using DeviceBatchedGemmGemmInstance =
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -71,7 +71,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -54,7 +54,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -48,7 +48,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
64, // BlockSize
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
|
||||
@@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
64, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
|
||||
@@ -26,9 +26,9 @@ using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<D
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
|
||||
{
|
||||
std::cout << "FP32 are not supported on gfx11 and gfx12" << std::endl;
|
||||
std::cout << "FP32 are not supported on gfx11 and gfx120x" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
// static constexpr auto KPerBlock = sizeof(InDataType) == 1 ? 64 : 32;
|
||||
|
||||
#ifdef EXAMPLE_USE_WMMA
|
||||
template <typename DataType,
|
||||
@@ -68,32 +69,32 @@ using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MWmmaPerWave
|
||||
4, // NWmmaPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
sizeof(DataType) == 1 ? 64 : 32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MWmmaPerWave
|
||||
4, // NWmmaPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
@@ -123,33 +124,33 @@ using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
sizeof(DataType) == 1 ? 64 : 32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
|
||||
@@ -20,7 +20,8 @@ add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_bl
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic)
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic gfx1250)
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -49,8 +49,6 @@ using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr int KPack = 8;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
@@ -58,6 +56,13 @@ using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr int KPerBlock = 64;
|
||||
#if defined(CK_USE_GFX1250)
|
||||
static constexpr int KPack = 16;
|
||||
#else
|
||||
static constexpr int KPack = 8;
|
||||
#endif
|
||||
static constexpr auto K0 = KPerBlock / KPack;
|
||||
// clang-format off
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle<
|
||||
@@ -65,12 +70,12 @@ using DeviceOpInstance =
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
128,
|
||||
32, 128, 128,
|
||||
8, 8,
|
||||
32, 128, KPerBlock,
|
||||
KPack, KPack,
|
||||
16, 16,
|
||||
2, 2,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<K0, 128 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<K0, 128 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1,
|
||||
|
||||
@@ -50,8 +50,6 @@ using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr int KPack = 16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
@@ -63,6 +61,13 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
static constexpr int KPerBlock = 128;
|
||||
#if defined(CK_USE_GFX1250)
|
||||
static constexpr int KPack = 32;
|
||||
#else
|
||||
static constexpr int KPack = 16;
|
||||
#endif
|
||||
static constexpr auto K0 = KPerBlock / KPack;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
|
||||
@@ -71,13 +76,13 @@ using DeviceOpInstance =
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128, 128,
|
||||
16, 16,
|
||||
128, 128, KPerBlock,
|
||||
KPack, KPack,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1,
|
||||
S<1, 32, 1, 8>, S<8>,
|
||||
|
||||
@@ -34,10 +34,9 @@ using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
static constexpr int KPack = 16;
|
||||
using ComputeType = F8;
|
||||
using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using ComputeType = F8;
|
||||
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
@@ -60,7 +59,13 @@ using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr int KPerBlock = 256;
|
||||
#if defined(CK_USE_GFX1250)
|
||||
static constexpr int KPack = 32;
|
||||
#else
|
||||
static constexpr int KPack = 16;
|
||||
#endif
|
||||
static constexpr auto K0 = KPerBlock / KPack;
|
||||
// clang-format off
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle<
|
||||
@@ -68,12 +73,12 @@ using DeviceOpInstance =
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256,
|
||||
32, 128, 256,
|
||||
16, 16,
|
||||
32, 128, KPerBlock,
|
||||
KPack, KPack,
|
||||
16, 16,
|
||||
2, 1,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1,
|
||||
|
||||
@@ -74,7 +74,14 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto PipelineVer = []() {
|
||||
#if defined(CK_USE_WMMA) && !defined(CK_USE_GFX1250)
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
#else
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
#endif
|
||||
}();
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
// clang-format off
|
||||
@@ -88,7 +95,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 16, 1, 16>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -114,7 +114,14 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto PipelineVer = []() {
|
||||
#if defined(CK_USE_WMMA) && !defined(CK_USE_GFX1250)
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
#else
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
#endif
|
||||
}();
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
// clang-format off
|
||||
@@ -135,7 +142,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, I8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -6,9 +6,15 @@ add_custom_target(example_gemm_mx)
|
||||
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_v1 gemm_mx_fp8_v1.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_v1)
|
||||
|
||||
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_bpreshuffle gemm_mx_fp8_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bpreshuffle)
|
||||
|
||||
# TODO: Fix RRR
|
||||
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
@@ -63,9 +69,81 @@ example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_M
|
||||
set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_fp8_v1 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_fp8_bpreshuffle PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP6_MXGEMM_OPTIONS)
|
||||
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
|
||||
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS})
|
||||
|
||||
function(add_gemm_mix_prec A_NAME B_NAME A_TYPE B_TYPE)
|
||||
add_example_executable(example_gemm_mx_${A_NAME}_${B_NAME} gemm_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_${A_NAME}_${B_NAME})
|
||||
|
||||
add_example_executable(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle)
|
||||
|
||||
example_compile_options(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE})
|
||||
|
||||
example_compile_options(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_gemm_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE})
|
||||
endfunction(add_gemm_mix_prec)
|
||||
|
||||
function(add_moe_mix_prec A_NAME B_NAME A_TYPE B_TYPE)
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns moe_gemm1_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} moe_gemm1_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME})
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} moe_gemm2_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME})
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle)
|
||||
|
||||
# mx moe B no-shuffling + scale shuffling
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE B_DATATYPE=${B_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bns PRIVATE B_DATATYPE=${B_TYPE})
|
||||
|
||||
# mx moe B no-shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME} PRIVATE B_DATATYPE=${B_TYPE})
|
||||
|
||||
# mx moe B shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm1_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE A_DATATYPE=${A_TYPE})
|
||||
target_compile_definitions(example_moe_gemm2_xdl_mx_${A_NAME}_${B_NAME}_bpreshuffle PRIVATE B_DATATYPE=${B_TYPE})
|
||||
endfunction(add_moe_mix_prec)
|
||||
|
||||
# mx mixed precsion
|
||||
if(GPU_TARGETS MATCHES "gfx125")
|
||||
add_gemm_mix_prec(fp4 fp8 F4 F8)
|
||||
add_gemm_mix_prec(fp8 fp4 F8 F4)
|
||||
|
||||
add_moe_mix_prec(fp4 fp8 F4 F8)
|
||||
add_moe_mix_prec(fp8 fp4 F8 F4)
|
||||
add_moe_mix_prec(fp8 fp8 F8 F8)
|
||||
endif()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::bf6x16_pk_t;
|
||||
using BDataType = ck::bf6x16_pk_t;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::bf8_t;
|
||||
using BDataType = ck::bf8_t;
|
||||
@@ -44,14 +45,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
128, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
32, // NPerBlock
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
@@ -27,9 +26,10 @@ using ::ck::Tensor;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using MFMA = ck::tensor_layout::gemm::MFMA;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using MFMA = ck::tensor_layout::gemm::MFMA;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
@@ -116,7 +116,7 @@ bool parse_cmd_args(int argc,
|
||||
}
|
||||
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
void preShuffleScaleBuffer_gfx950(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
@@ -126,8 +126,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
// On gfx950, WarpSize=64:
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
@@ -163,13 +164,74 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
|
||||
}
|
||||
}
|
||||
|
||||
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
|
||||
/**
|
||||
* Pre-shuffle scale buffer for gfx1250 16x16x128 wmma scale instruction
|
||||
*
|
||||
* @tparam ScaleType Scale data type
|
||||
* @tparam KStride Whether K is the leading dimension of the scale buffer
|
||||
*/
|
||||
template <typename ScaleType, ck::index_t ScaleBlockSize, bool KStride>
|
||||
void preShuffleScaleBuffer_gfx1250(const ScaleType* src,
|
||||
ScaleType* dst,
|
||||
ck::index_t MN,
|
||||
ck::index_t K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
|
||||
static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1,
|
||||
"wrong! only support 8-bit scale with ScaleBlockSize=32");
|
||||
|
||||
constexpr ck::index_t MPerXdlops = 16;
|
||||
// constexpr ck::index_t NPerXdlops = 16;
|
||||
constexpr ck::index_t KPerXdlops = 128;
|
||||
|
||||
int MNPack = 2; // 2 sets of scales in M/N direction
|
||||
int KPack = 1; // 1 set of scales in K direction
|
||||
|
||||
int MNStep = MPerXdlops;
|
||||
int KStep = KPerXdlops / ScaleBlockSize; // scales per thread
|
||||
|
||||
int K0 = K / KPack / KStep; // KRepeat - how many KStep blocks
|
||||
|
||||
// On gfx1250, WarpSize=32:
|
||||
// -- The 2 16x128 building blocks will be packed into 1 32x128
|
||||
// -- The 4 16x16x128 wmma will be packed into 1 32x32x128
|
||||
|
||||
// unfold the MN32xK(128/32) scale buffer
|
||||
// 4 16 1 2
|
||||
// To KStep -> MNStep -> KPack -> MNPack
|
||||
// or ???
|
||||
// 2 16 1 4
|
||||
// MNPack -> MNStep -> KPack -> KStep
|
||||
for(int mn = 0; mn < MN; ++mn)
|
||||
{
|
||||
int iMNRepeat = mn / (MNStep * MNPack); // i MNRepeat (MN block id)
|
||||
int tempmn = mn % (MNStep * MNPack); // position in MN block
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int iKRepeat = k / (KStep * KPack); // i KRepeat
|
||||
int tempk = k % (KStep * KPack); // position in KStep block
|
||||
|
||||
int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) +
|
||||
(iKRepeat * KStep * KPack) * (MNStep * MNPack) +
|
||||
tempmn * (KStep * KPack) + tempk;
|
||||
|
||||
if constexpr(KStride)
|
||||
dst[outputIndex] = src[mn * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + mn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void preShuffleBuffer(const T* src, T* dst, int N, int K, int NXdl)
|
||||
{
|
||||
const int KPack = 16;
|
||||
const int NLane = NXdl;
|
||||
const int KLane = ck::get_warp_size() / NLane;
|
||||
const int K_pk = K / ck::packed_size_v<T>;
|
||||
const int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
@@ -352,7 +414,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorDistr(
|
||||
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
|
||||
@@ -369,12 +430,34 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
|
||||
a_shuffled_scale.mData.data(),
|
||||
Scale_Padded_M,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
|
||||
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
|
||||
if(ck::get_warp_size() == 64)
|
||||
{
|
||||
preShuffleScaleBuffer_gfx950<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
|
||||
a_shuffled_scale.mData.data(),
|
||||
Scale_Padded_M,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
preShuffleScaleBuffer_gfx950<ck::is_same_v<BRefLayout, Col>>(
|
||||
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
|
||||
}
|
||||
else if(ck::get_warp_size() == 32)
|
||||
{
|
||||
preShuffleScaleBuffer_gfx1250<ck::e8m0_bexp_t, ScaleBlockSize, ck::is_same_v<ALayout, Row>>(
|
||||
a_m_k_scale.mData.data(),
|
||||
a_shuffled_scale.mData.data(),
|
||||
Scale_Padded_M,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
preShuffleScaleBuffer_gfx1250<ck::e8m0_bexp_t,
|
||||
ScaleBlockSize,
|
||||
ck::is_same_v<BRefLayout, Col>>(
|
||||
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! Scale pre-shuffle unsupported warp size");
|
||||
}
|
||||
|
||||
if constexpr(BPreShuffle)
|
||||
{
|
||||
int NPerXdl = 16; // Fixed 16
|
||||
@@ -459,7 +542,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Done." << std::endl;
|
||||
std::cout << "\nDone." << std::endl;
|
||||
std::cout << "Computing GEMM on host..." << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
using ADataType = ck::f4x2_pk_t;
|
||||
using BDataType = ck::f4x2_pk_t;
|
||||
#if defined(A_DATATYPE)
|
||||
using ADataType = A_DATATYPE;
|
||||
#else
|
||||
using ADataType = F4;
|
||||
#endif
|
||||
#if defined(B_DATATYPE)
|
||||
using BDataType = B_DATATYPE;
|
||||
#else
|
||||
using BDataType = F4;
|
||||
#endif
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
@@ -21,7 +32,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t DataPackedSize =
|
||||
ck::packed_size_v<ADataType>; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
|
||||
@@ -2,9 +2,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f4x2_pk_t;
|
||||
using BDataType = ck::f4x2_pk_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
#if defined(A_DATATYPE)
|
||||
using ADataType = A_DATATYPE;
|
||||
#else
|
||||
using ADataType = F4;
|
||||
#endif
|
||||
#if defined(B_DATATYPE)
|
||||
using BDataType = B_DATATYPE;
|
||||
#else
|
||||
using BDataType = F4;
|
||||
#endif
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
@@ -21,7 +33,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t DataPackedSize =
|
||||
ck::packed_size_v<ADataType>; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
@@ -30,7 +43,7 @@ constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
// AB DataType: f4x2_pk_t
|
||||
// Mathmatically, all numbers are represented as f4x2.
|
||||
// Mathematically, all numbers are represented as f4x2.
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
@@ -47,24 +60,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
512, // NPerBlock
|
||||
128, // BlockSize: Thread block size
|
||||
64, // MPerBlock
|
||||
64, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
@@ -72,9 +85,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
4, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlockW
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 8, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
2, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f6x16_pk_t;
|
||||
using BDataType = ck::f6x16_pk_t;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::bf8_t;
|
||||
|
||||
101
example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp
Normal file
101
example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = MFMA;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XPackedDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XPackedDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
101
example/67_gemm_microscaling/gemm_mx_fp8_v1.cpp
Normal file
101
example/67_gemm_microscaling/gemm_mx_fp8_v1.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XPackedDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XPackedDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
64, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -1,47 +1,29 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "gemm_mx_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using A0DataType = F4;
|
||||
#if defined(A_DATATYPE)
|
||||
using A0DataType = A_DATATYPE;
|
||||
#else
|
||||
using A0DataType = F4;
|
||||
#endif
|
||||
#if defined(B_DATATYPE)
|
||||
using B0DataType = B_DATATYPE;
|
||||
#else
|
||||
using B0DataType = F4;
|
||||
#endif
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
@@ -89,67 +71,14 @@ struct MulABScaleExpertWeight
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
@@ -157,6 +86,11 @@ static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t BlockSize = 256;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
static constexpr ck::index_t ClusterLengths_BK0 =
|
||||
ck::is_same_v<B0DataType, F4> && ck::is_same_v<A0DataType, F4> ? 8 : 4;
|
||||
static constexpr ck::index_t ClusterLengths_N =
|
||||
ck::is_same_v<B0DataType, F4> && ck::is_same_v<A0DataType, F4> ? 32 : 64;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
@@ -166,10 +100,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
2, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
S<ClusterLengths_BK0, ClusterLengths_N, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<4, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
@@ -373,6 +307,9 @@ int main(int argc, char* argv[])
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// a0_t_k.savetxt("a.txt", "float", 128);
|
||||
// b0_e_n_k.savetxt("b.txt", "float", 128);
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
@@ -392,14 +329,30 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
if(ck::is_gfx125_supported())
|
||||
{
|
||||
preShuffleScaleBuffer_gfx1250<XDataType, ScaleBlockSize, ck::is_same_v<A0Layout, Row>>(
|
||||
a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer_gfx1250<XDataType, ScaleBlockSize, ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
preShuffleScaleBuffer_gfx950<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer_gfx950<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
}
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
@@ -452,9 +405,10 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
ck::is_gfx125_supported()))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user