[rocm-libraries] ROCm/rocm-libraries#6978 (commit e58096d)

[CK] add composable kernel support on gfx1250 (#6978)

## Motivation

Add composable kernel support on gfx1250.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Qun Lin <qlin@amd.com>
Co-authored-by: jialuo12_amdeng <jia.luo@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: hsivasun_amdeng <haresh.sivasuntharampillai@amd.com>
This commit is contained in:
Illia Silin
2026-05-15 06:46:51 -07:00
committed by GitHub
parent ac18460782
commit 717f2efef7
912 changed files with 73598 additions and 11750 deletions

View File

@@ -50,6 +50,7 @@ option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
option(CK_EXPERIMENTAL_GEMM_BENCHMARK "Enable experimental gemm benchmark for gfx1250" OFF)
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(BUILD_CK_TILE_ENGINE "Build the tile_engine subdirectory" OFF)
@@ -303,6 +304,9 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
add_definitions(-DCK_GFX1030_SUPPORT)
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
@@ -321,6 +325,8 @@ endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
add_definitions(-DCK_TILE_USE_OCP_FP8)
set(CK_TILE_USE_OCP_FP8 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
@@ -332,6 +338,16 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_GFX950_SUPPORT)
set(CK_GFX950_SUPPORT "ON")
endif()
# gfx1250 targets: define the gfx1250 feature macros for all sources in this directory
# scope and mirror two of them (CK_USE_NATIVE_MX_SUPPORT, CK_GFX1250_SUPPORT) as CMake
# variables set to "ON".
# NOTE(review): both CK_USE_GFX1250 and CK_GFX1250_SUPPORT are defined here — confirm that
# both spellings are actually consumed by the sources.
if (SUPPORTED_GPU_TARGETS MATCHES "gfx1250")
add_definitions(-DCK_USE_GFX1250)
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
set(CK_USE_NATIVE_MX_SUPPORT "ON")
add_definitions(-DCK_GFX1250_SUPPORT)
set(CK_GFX1250_SUPPORT "ON")
endif()
# MATCHES is an unanchored regex search, so "gfx12" is also found in "gfx1250":
# CK_GFX12_SUPPORT is defined for gfx1250 builds as well as for other gfx12* targets.
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
add_definitions(-DCK_GFX12_SUPPORT)
endif()
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
@@ -413,6 +429,7 @@ endif()
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF)
option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF)
option(CK_TEST_DISABLE_GPU_VALIDATION "Whether to disable GPU validation in CK tests." OFF )
if(USE_BITINT_EXTENSION_INT4)
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
@@ -431,6 +448,10 @@ if (ENABLE_JSON_DUMP)
message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}")
endif()
if (CK_TEST_DISABLE_GPU_VALIDATION)
add_compile_definitions(CK_TEST_DISABLE_GPU_VALIDATION)
endif()
## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
@@ -797,6 +818,10 @@ if(BUILD_CK_PROFILER)
endif()
endif()
if (CK_EXPERIMENTAL_GEMM_BENCHMARK)
add_subdirectory(experimental/gemm_benchmark)
endif()
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)
endif()

146
Jenkinsfile vendored
View File

@@ -772,6 +772,26 @@ def cmake_build(Map conf=[:]){
try {
//build CK
sh cmd
if (runAllUnitTests){
// Archive artifacts if they were generated
if (fileExists("ck_build_trace_${arch_name}.json")) {
archiveArtifacts "ck_build_trace_${arch_name}.json"
}
if (fileExists("clang_build_analysis_${arch_name}.log")) {
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
// Process ninja build trace after full build
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
archiveArtifacts "ck_build_trace_${arch_name}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
if (params.NINJA_FTIME_TRACE) {
echo "running ClangBuildAnalyzer"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
}
} catch (Exception buildError) {
echo "Build failed: ${buildError.getMessage()}"
throw buildError
@@ -796,49 +816,32 @@ def cmake_build(Map conf=[:]){
}
}
//run tests except when NO_CK_BUILD is set
//run tests except when NO_CK_BUILD is set and except on gfx1250
if(!setup_args.contains("NO_CK_BUILD")){
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
// do not run unit tests when building instances only
if(!params.BUILD_INSTANCES_ONLY){
if (!runAllUnitTests){
// Smart Build: Run smart_build_and_test.sh
sh """
export WORKSPACE_ROOT=${env.WORKSPACE}
export PARALLEL=32
export NINJA_JOBS=${nt}
export ARCH_NAME=${arch_name}
export PROCESS_NINJA_TRACE=true
export NINJA_FTIME_TRACE=${params.NINJA_FTIME_TRACE ? 'true' : 'false'}
bash ../script/dependency-parser/smart_build_and_test.sh
"""
// Archive artifacts if they were generated
if (fileExists("ck_build_trace_${arch_name}.json")) {
archiveArtifacts "ck_build_trace_${arch_name}.json"
}
if (fileExists("clang_build_analysis_${arch_name}.log")) {
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
}
else{
// run unit tests unless building library for all targets
// Note: This else block is when NINJA_BUILD_TRACE=false and BUILD_INSTANCES_ONLY=false
// So no ninja trace processing needed here
if (!params.BUILD_INSTANCES_ONLY){
if (!runAllUnitTests && !setup_args.contains("gfx1250") ){
// Smart Build: Run smart_build_and_test.sh
sh """
export WORKSPACE_ROOT=${env.WORKSPACE}
export PARALLEL=32
export NINJA_JOBS=${nt}
export ARCH_NAME=${arch_name}
export PROCESS_NINJA_TRACE=false
export NINJA_FTIME_TRACE=false
bash ../script/dependency-parser/smart_build_and_test.sh
"""
}
else{ //run all tests
if(!setup_args.contains("gfx1250")){
echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)"
sh "ninja -j${nt} check"
// Process ninja build trace after full build
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
archiveArtifacts "ck_build_trace_${arch_name}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
if (params.NINJA_FTIME_TRACE) {
echo "running ClangBuildAnalyzer"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
}
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
sh 'ninja check-builder'
else{ //do not run tests on gfx1250, just build everything
echo "Building for gfx1250"
sh "ninja -j${nt}"
}
if (params.RUN_ROCM_CK_TESTS) {
sh 'ninja check-rocm-ck'
@@ -850,47 +853,8 @@ def cmake_build(Map conf=[:]){
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
if(params.BUILD_INSTANCES_ONLY){
// build deb packages
echo "Build library package"
sh 'ninja -j64 package'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
stash includes: "composablekernel-dev**.deb", name: "lib_package"
}
}
else{
// run unit tests unless building library for all targets
// Note: This else block is when NINJA_BUILD_TRACE=false and BUILD_INSTANCES_ONLY=false
// So no ninja trace processing needed here
if (!params.BUILD_INSTANCES_ONLY){
if (!runAllUnitTests){
// Smart Build: Run smart_build_and_test.sh
sh """
export WORKSPACE_ROOT=${env.WORKSPACE}
export PARALLEL=32
export NINJA_JOBS=${nt}
export ARCH_NAME=${arch_name}
export PROCESS_NINJA_TRACE=false
export NINJA_FTIME_TRACE=false
bash ../script/dependency-parser/smart_build_and_test.sh
"""
}
else{
echo "Full test suite requested (RUN_ALL_UNIT_TESTS=true or develop branch)"
sh "ninja -j${nt} check"
}
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
sh 'ninja check-builder'
}
if (params.RUN_ROCM_CK_TESTS) {
sh 'ninja check-rocm-ck'
}
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
if (params.RUN_BUILDER_TESTS && !setup_args.contains("-DCK_CXX_STANDARD=") && !setup_args.contains("gfx10") && !setup_args.contains("gfx11")) {
sh 'ninja check-builder'
}
}
}
@@ -1414,6 +1378,10 @@ pipeline {
name: "BUILD_GFX12",
defaultValue: true,
description: "Build CK and run tests on gfx12 (default: ON)")
booleanParam(
name: "BUILD_GFX1250",
defaultValue: true,
description: "Build CK for gfx1250 (default: ON)")
booleanParam(
name: "NINJA_BUILD_TRACE",
defaultValue: true,
@@ -2052,6 +2020,7 @@ pipeline {
cleanWs()
}
}
/*
stage("Build CK and run Tests on gfx1010")
{
when {
@@ -2068,6 +2037,7 @@ pipeline {
cleanWs()
}
}
*/
stage("Build CK and run Tests on gfx1030")
{
when {
@@ -2116,6 +2086,21 @@ pipeline {
cleanWs()
}
}
stage("Build CK for gfx1250")
{
when {
beforeAgent true
expression { params.BUILD_GFX1250.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1250" -DDISABLE_DL_KERNELS="ON" """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:npi-mi450-latest", config_targets: "install", no_reboot:true, build_type: 'Release', prefixpath: '/usr/local')
cleanWs()
}
}
}
post {
always {
@@ -2194,3 +2179,4 @@ pipeline {
}
}
}

View File

@@ -47,7 +47,7 @@ list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllv
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -82,7 +82,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
list(APPEND gpu_list gfx90a gfx942 gfx950)
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -95,7 +95,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic)
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic gfx1250)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -141,6 +141,9 @@ add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.c
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3)
add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3)
add_example_executable(example_gemm_wmma_fp8_v3_reg_spill gemm_wmma_fp8_v3_reg_spill.cpp)
example_compile_options(example_gemm_wmma_fp8_v3_reg_spill PRIVATE "SHELL: -Rpass-analysis=kernel-resource-usage ")
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3_reg_spill)
add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3)
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp)
@@ -149,6 +152,11 @@ add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx125")
add_example_executable(example_gemm_xdl_bf16_v3_prefetch gemm_xdl_bf16_v3_prefetch.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3_prefetch)
endif()
add_example_executable(example_gemm_wmma_fp8_bpreshuffle gemm_wmma_fp8_bpreshuffle.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_bpreshuffle)
add_example_executable(example_gemm_wmma_fp16_bpreshuffle gemm_wmma_fp16_bpreshuffle.cpp)

View File

@@ -29,7 +29,7 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64, 64,
128, 64, 128,
16, 16, // AK1, BK1
16, 16,
4, 2,
@@ -42,7 +42,6 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceComputeType = ck::f8_t;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
@@ -50,8 +49,8 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
AElementOp,
BElementOp,
CElementOp,
ReferenceComputeType,
ReferenceComputeType>;
ComputeTypeA,
ComputeTypeB>;
#include "run_gemm_example_v2.inc"

View File

@@ -0,0 +1,113 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* \brief Example of GEMM using WMMA that illustrates register spilling on gfx1200 architecture but
* no register spilling on gfx1250.
*
* This example demonstrates how more registers available on the gfx1250 architecture can help avoid
* register spilling that occurs on gfx1200 when using a specific GEMM configuration.
*
* This example must be compiled with the following flag to see the resource allocations:
* "-Rpass-analysis=kernel-resource-usage"
*
* On gfx1200/gfx1201 (the sample output below was captured on a gfx1201 device), the kernel
* shows register spilling due to limited VGPRs:
* \verbatim
* TotalSGPRs: 105
* VGPRs: 256
* ScratchSize [bytes/lane]: 56
* Dynamic Stack: False
* Occupancy [waves/SIMD]: 5
* SGPRs Spill: 0
* VGPRs Spill: 15
* LDS Size [bytes/block]: 32768
*
* gfx1201 - AMD Radeon RX 9070 XT
* Problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:3840, NP:4096, KRead:4096,
* KP:4096, AK0:512, BK0:512, MBlock: 30, NBlock: 32}
*
* Perf: 0.882764 ms, 145.961 TFlops, 72.4578 GB/s, DeviceGemm_Wmma_CShuffleV3<Default, RCR>
* BlkSize: 128, BlkTile: 128x128x128, WaveTile: 16x16, WaveMap: 4x4, VmemReadVec: 8x8,
* BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, BlkGemmPipelinePrefetchStages:
* 1, KPack: 16
* \endverbatim
*
* On gfx1250, the same kernel will not show register spilling due to increased VGPRs:
* \verbatim
* TotalSGPRs: 32
* VGPRs: 318
* ScratchSize [bytes/lane]: 0
* Dynamic Stack: False
* Occupancy [waves/SIMD]: 3
* SGPRs Spill: 0
* VGPRs Spill: 0
* LDS Size [bytes/block]: 32768
* \endverbatim
*
* \note The register allocations above can be influenced by compiler version and code
* changes/optimizations.
*/
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
// Element types used by this example: f8 inputs, float accumulation, and bhalf for both the
// C-shuffle stage and the output tensor.
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
// Compute types; passed as the trailing template arguments of both the device instance and
// the host reference GEMM below, so the two use the same compute types.
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
// Tensor layouts. Row/Col are not declared in this file — presumably provided by common.hpp.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// All three element-wise operations are PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Device GEMM instance whose resource usage is discussed in the file header.
// The first argument rows carry their own labels; the unlabelled rows after NRepeat are
// presumably the A/B block-transfer and C-shuffle settings — TODO confirm against the
// DeviceGemm_Wmma_CShuffleV3 parameter list.
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128, //blocksize
128, 128, 128, // M/N/KPerBlock
8, 8, // AK1, BK1
16, 16, //MPerWmma, NPerWmma
4, 4, //MRepeat, NRepeat
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,//
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
ComputeTypeA, ComputeTypeB>;
// clang-format on
// Host reference GEMM instantiated with the same data, element-op and compute types as the
// device instance above.
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeTypeA,
ComputeTypeB>;
#include "run_gemm_example_v2.inc"
// Entry point. Runs the split-K GEMM example when ck::is_gfx12_supported() reports a
// supported device; otherwise prints a notice and exits with status 0 so that an
// unsupported device is not reported as a failure.
int main(int argc, char* argv[])
{
    const bool device_supported = ck::is_gfx12_supported();

    if(device_supported)
    {
        // The example's result is negated to form the exit code: a truthy result yields 0.
        return !run_gemm_splitk_example(argc, argv);
    }

    std::cout << "This kernel support gfx12 only" << std::endl;
    return 0;
}

View File

@@ -0,0 +1,325 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
// Element types used by this example: bhalf inputs and output, float accumulation.
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
// Tensor layouts. Row/Col are not declared in this file — presumably provided by common.hpp.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// All three element-wise operations are PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Two alternative tuning configurations follow; only the first (#if 1) branch is compiled.
// Change the condition to 0 to build the alternative configuration in the #else branch.
// In both branches the template parameter UseDataCachePrefetch is forwarded as the last
// argument of DeviceGemm_Xdl_CShuffleV3, so the same tile configuration can be instantiated
// with and without it.
// NOTE(review): the numeric rows are unlabelled; their meaning follows the parameter order
// of DeviceGemm_Xdl_CShuffleV3 — confirm against that declaration before editing.
#if 1
// Value reused for several arguments of the instance below (wherever AB_K1 appears).
static const uint32_t AB_K1 = 8;
// clang-format off
template <bool UseDataCachePrefetch>
using DeviceGemmV3Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
256, 256,
64, AB_K1, AB_K1,
16, 16,
8, 16,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, AB_K1, AB_K1, 0,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, AB_K1, AB_K1, 0,
2, 4, S<1, 8, 1, 16>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3,
CDataType,
CDataType,
false,
false,
0,
UseDataCachePrefetch>;
// clang-format on
#else
// prefetch is faster on these params
// clang-format off
template <bool UseDataCachePrefetch>
using DeviceGemmV3Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128,
64, 8, 8,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3,
CDataType,
CDataType,
false,
false,
0,
UseDataCachePrefetch>;
// clang-format on
#endif
// Host reference GEMM used by run_gemm to produce the expected output for verification.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
/**
 * \brief Runs one device GEMM instance on freshly generated inputs, optionally verifying the
 *        result against the host reference GEMM and optionally timing the kernel.
 *
 * \tparam GemmInstanceType Device GEMM type to instantiate and run.
 * \tparam ProblemType      Problem description providing M, N, K, StrideA, StrideB, StrideC
 *                          and KBatch members.
 * \param problem_size Problem dimensions; a stride of -1 or 0 selects the packed default.
 * \param config       Provides init_method, do_verification and time_kernel.
 * \return Pair of (verification passed, average time in ms as reported by the invoker).
 *         When the instance does not support the problem, returns (true, 0) without running.
 *         When verification is not requested the first element stays true.
 */
template <typename GemmInstanceType, typename ProblemType>
std::pair<bool, float> run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
// Builds a 2D host descriptor: for RowMajor the strides are {stride, 1}, for any other
// layout they are {1, stride}.
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
// Resolves the leading stride: -1 or 0 means "not specified by the caller".
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1 || stride == 0)
{
// unspecified stride: use the packed stride (col for RowMajor, row otherwise)
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// Input initialization, selected by config.init_method:
//   0: A and B from GeneratorTensor_1{1}
//   1: A and B from GeneratorTensor_2{-2, 2}
//   2: A from GeneratorTensor_1{1}, B from GeneratorTensor_2{-2, 2}
//   3: A from GeneratorTensor_2{-2, 2}, B from GeneratorTensor_1{1}
//   any other value: A from GeneratorTensor_3{0.0, 1.0}, B from GeneratorTensor_3{-0.5, 0.5}
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
// Separate output tensors for the host reference and for the data copied back from device.
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
// Device buffers sized from each descriptor's element-space size; A and B are uploaded.
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
// NOTE(review): `workspace` is never used in this function — candidate for removal.
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// Build the device GEMM, its invoker and its argument.
auto gemm = GemmInstanceType{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
// An unsupported problem is reported on stderr but returned as "pass" with time 0.
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return std::make_pair(true, ave_time);
}
bool pass = true;
// Verification runs only for do_verification values 1 and 3, and always compares against
// the host reference; no other value triggers a check in this function.
if((config.do_verification == 1) || (config.do_verification == 3))
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
// Device run whose output is verified; ave_time is overwritten below if timing is enabled.
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
// Timed run. NOTE(review): the trailing StreamConfig arguments (0, 50, 100, true, 4) are
// passed positionally; their meaning is defined by StreamConfig — confirm there.
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
// With ave_time in ms: flop / 1e9 / ms gives TFLOP/s and bytes / 1e6 / ms gives GB/s.
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return std::make_pair(pass, ave_time);
}
/**
 * \brief Parses the command line of the data-cache-prefetch GEMM example.
 *
 * Accepted forms (argc counts the program name):
 *  - argc == 1:  keep the defaults already stored in \p problem_size and \p config.
 *  - argc == 4:  arg1..arg3 = do_verification, init_method, time_kernel.
 *  - argc >= 10: additionally arg4..arg9 = M, N, K, StrideA, StrideB, StrideC.
 *  - argc >= 11: additionally arg10 = KBatch.
 *  - argc >= 12: additionally arg11 = compare against the instance built without data-cache
 *                prefetch (0 = no, non-zero = yes).
 * Any other argument count prints the usage text to stderr.
 *
 * \param argc         Argument count as passed to main.
 * \param argv         Argument vector as passed to main.
 * \param problem_size Updated in place with the parsed problem dimensions.
 * \param config       Updated in place with the parsed execution options.
 * \param compareWithNonDataCachePrefetchImpl Always written: false unless arg11 is given.
 * \return true when the arguments were accepted, false when the usage text was printed.
 *
 * \note Values are converted with std::stoi, which throws std::invalid_argument or
 *       std::out_of_range on malformed input; those exceptions are not caught here.
 */
bool parse_cmd_args(int argc,
                    char* argv[],
                    ProblemSizeSplitK& problem_size,
                    ExecutionConfig& config,
                    bool& compareWithNonDataCachePrefetchImpl)
{
    compareWithNonDataCachePrefetchImpl = false;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4 || argc >= 10)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        if(argc >= 10)
        {
            problem_size.M = std::stoi(argv[4]);
            problem_size.N = std::stoi(argv[5]);
            problem_size.K = std::stoi(argv[6]);

            problem_size.StrideA = std::stoi(argv[7]);
            problem_size.StrideB = std::stoi(argv[8]);
            problem_size.StrideC = std::stoi(argv[9]);

            if(argc >= 11)
            {
                problem_size.KBatch = std::stoi(argv[10]);

                // argv[11] is present as soon as argc >= 12. The previous `argc > 12` test
                // ignored arg11 unless an extra, unused 12th argument was also supplied.
                if(argc >= 12)
                {
                    compareWithNonDataCachePrefetchImpl = (std::stoi(argv[11]) != 0);
                }
            }
        }
    }
    else
    {
        std::cerr
            << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
            << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
            << "arg3: time kernel (0=no, 1=yes)" << std::endl
            << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
            << std::endl
            << "arg10: KBatch" << std::endl
            << "arg11: compareWithNonDataCachePrefetchImpl (0=no, 1=yes)" << std::endl;
        return false;
    }

    return true;
}
/**
 * \brief Entry point of the data-cache-prefetch GEMM example.
 *
 * Runs the GEMM instance built with data-cache prefetch enabled. When requested on the
 * command line, also runs the instance built with it disabled and prints both average times
 * and the resulting speedup.
 *
 * \return 1 when argument parsing fails or any executed variant fails verification,
 *         0 otherwise.
 */
int main(int argc, char* argv[])
{
    ProblemSizeSplitK problem_size;
    ExecutionConfig config;
    // Initialized here so the flag has a defined value regardless of what the parser does.
    bool compareWithNonDataCachePrefetchImpl = false;

    if(!parse_cmd_args(argc, argv, problem_size, config, compareWithNonDataCachePrefetchImpl))
    {
        return 1;
    }

    auto [pass, ave_time] = run_gemm<DeviceGemmV3Instance<true>>(problem_size, config);
    bool all_pass         = pass;

    if(compareWithNonDataCachePrefetchImpl)
    {
        auto [pass2, ave_time2] = run_gemm<DeviceGemmV3Instance<false>>(problem_size, config);

        // A verification failure of the comparison variant must also fail the process;
        // previously its result was discarded.
        all_pass = all_pass && pass2;

        std::cout << "DataCache Prefetching enabled ave_time: " << ave_time << " ms" << std::endl;
        std::cout << "DataCache Prefetching disabled ave_time: " << ave_time2 << " ms" << std::endl;

        // Only report a ratio when both runs produced a positive time; run_gemm returns 0
        // for an unsupported problem, which would otherwise give a meaningless ratio (0/0).
        if(ave_time > 0.0f && ave_time2 > 0.0f)
        {
            const float speedup = ave_time2 / ave_time;
            std::cout << "On average kernel with DataCache prefetching is " << speedup
                      << " times faster than without DataCache prefetching." << std::endl;
            if(speedup < 1.0f)
                std::cout << "WARNING: Kernel with DataCache prefetching is slower!" << std::endl;
        }
    }

    return all_pass ? 0 : 1;
}

View File

@@ -28,10 +28,10 @@ using DeviceGemmV2_Streamk_Instance =
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
64,
16, 16,
32, 32,
256, 8, 16,
16, 16,
1, 1,
2, 2,
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#define CK_ENABLE_DYNAMIC_WARP_SIZE 1
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
// Element types: every tensor and the accumulator use float in this example.
using ADataType = float;
using BDataType = float;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = float;
// Tensor layouts. Row/Col are not declared in this file — presumably provided by common.hpp.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// All three element-wise operations are PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Device GEMM instance used by the included example runner.
// NOTE(review): the numeric rows are unlabelled; their meaning follows the parameter order
// of DeviceGemm_Xdl_CShuffleV3 — confirm against that declaration before editing.
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
64, 64,
64, 4, 4,
16, 16,
2, 4,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 2, 2, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 2, 2, 0,
1, 2, S<1, 32, 1, 4>, 2,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
// Host reference GEMM instantiated with the same data and element-op types as above.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_v2.inc"
// Entry point: runs the split-K GEMM example and negates its result to form the exit
// code, so a truthy result from the example yields exit status 0.
int main(int argc, char* argv[])
{
    const auto example_result = run_gemm_splitk_example(argc, argv);
    return !example_result;
}

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#define CK_ENABLE_DYNAMIC_WARP_SIZE 1
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
// Element types: every tensor and the accumulator use double in this example.
using ADataType = double;
using BDataType = double;
using AccDataType = double;
using CShuffleDataType = double;
using CDataType = double;
// Tensor layouts. Row/Col are not declared in this file — presumably provided by common.hpp.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
// All three element-wise operations are PassThrough.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Device GEMM instance used by the included example runner.
// NOTE(review): the numeric rows are unlabelled; their meaning follows the parameter order
// of DeviceGemm_Xdl_CShuffleV3 — confirm against that declaration before editing.
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
64, 64,
64, 4, 4,
16, 16,
2, 4,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 2, 2, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 2, 2, 0,
1, 2, S<1, 32, 1, 4>, 2,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
// Host reference GEMM instantiated with the same data and element-op types as above.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_v2.inc"
// Entry point: runs the split-K GEMM example and negates its result to form the exit
// code, so a truthy result from the example yields exit status 0.
int main(int argc, char* argv[])
{
    const auto example_result = run_gemm_splitk_example(argc, argv);
    return !example_result;
}

View File

@@ -29,8 +29,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipBLds
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| buffer| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| size | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 0
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>;
#if 0
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 8, 7, 1>;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;

View File

@@ -181,7 +181,7 @@ int main(int argc, char* argv[])
exit(0);
}
bool is_supported = ck::is_gfx11_supported();
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()

View File

@@ -181,7 +181,7 @@ int main(int argc, char* argv[])
exit(0);
}
bool is_supported = ck::is_gfx11_supported();
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx125_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()

View File

@@ -19,7 +19,7 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
list(APPEND gpu_list gfx90a gfx942 gfx950)
list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1250)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)

View File

@@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance =
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder

View File

@@ -49,9 +49,9 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // KPerBlock
32, // AK1
32, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave

View File

@@ -50,9 +50,9 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // KPerBlock
32, // AK1
32, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave

View File

@@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance =
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder

View File

@@ -48,14 +48,14 @@ using DeviceGroupedConvNDFwdInstance =
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // NPerBlock
128, // KPerBlock
32, // AK1
32, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder

View File

@@ -51,10 +51,10 @@ using DeviceGroupedConvNDFwdInstance =
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -72,13 +72,13 @@ using DeviceGroupedConvNDFwdInstance =
1,
1,
S<1, 16, 1, 16>,
4>;
2>;
#include "run_convnd_fwd_example.inc"
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -51,9 +51,9 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // KPerBlock
32, // AK1
32, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave

View File

@@ -50,9 +50,9 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
128, // KPerBlock
32, // AK1
32, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave

View File

@@ -17,7 +17,7 @@ using RsDataType = ck::Tuple<R0DataType>;
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -25,6 +25,7 @@ static constexpr auto ConvSpec =
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto KPerBlock = sizeof(ADataType) == 1 ? 64 : 32;
// clang-format off
template <ck::index_t NDimSpatial>
using DeviceInstance =
@@ -34,9 +35,9 @@ using DeviceInstance =
//######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| Specialization| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
#ifdef BUILD_INT4_EXAMPLE
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
#else
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>;
< NDimSpatial, ALayout<NDimSpatial>, BLayout<NDimSpatial>, DELayout<NDimSpatial>, RLayout<NDimSpatial>, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, KPerBlock, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>;
#endif
template <ck::index_t NDimSpatial>
@@ -292,15 +293,15 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
conv_output_device_buf.FromDevice(conv_output_device.mData.data());
r0_device_buf.FromDevice(r0_device.mData.data());
auto rtol = ck::is_same_v<EDataType, BF16> ? 1e-1f : 1e-3f;
auto pass = ck::utils::check_err(conv_output_device,
conv_output_host,
"Error: incorrect results! (Matrix E)",
1e-3f,
rtol,
1e-3f);
pass =
pass && ck::utils::check_err(
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f);
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", rtol, 1e-3f);
if(pass)
std::cout << "Verification on CPU: PASS" << std::endl;

View File

@@ -328,7 +328,8 @@ int main(int argc, char* argv[])
problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
if(argc == 5)
if(argc == 1) {}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
@@ -302,7 +302,8 @@ int main(int argc, char* argv[])
problem_size.group_count = 16;
if(argc == 5)
if(argc == 1) {}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
@@ -287,7 +287,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
@@ -302,7 +301,8 @@ int main(int argc, char* argv[])
problem_size.group_count = 16;
if(argc == 5)
if(argc == 1) {}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -59,7 +59,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
@@ -71,7 +71,8 @@ int main(int argc, char* argv[])
problem_size.group_count = 16;
if(argc == 4)
if(argc == 1) {}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);

View File

@@ -69,10 +69,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
8, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
@@ -89,7 +89,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
@@ -121,25 +121,21 @@ int main(int argc, char* argv[])
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
else if(argc == 4 || argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
if(argc == 10)
{
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
}
else
{
@@ -147,10 +143,10 @@ int main(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: Measure kernel execution time (1=ON, 0=Off)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0);
exit(1);
}
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -76,10 +76,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
16, // MPerXdl
16, // NPerXdl
8, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
@@ -96,7 +96,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
@@ -127,25 +127,21 @@ int main(int argc, char* argv[])
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
else if(argc == 4 || argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
if(argc == 10)
{
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
}
else
{
@@ -154,10 +150,10 @@ int main(int argc, char* argv[])
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
exit(1);
}
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
exit(EXIT_SUCCESS);
}

View File

@@ -4,4 +4,6 @@
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
add_example_executable(example_elementwise_tanh_1d elementwise_tanh_1d.cpp)
add_example_executable(example_elementwise_fastgelu_1d elementwise_fastgelu_1d.cpp)

View File

@@ -0,0 +1,130 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using ADataType = F16;
using CDataType = F16;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
// Device-side elementwise instance used by this example: a single ADataType input tensor, a
// single CDataType output tensor, and FastGelu as the elementwise operation.
// NOTE(review): the numeric arguments (1, 64, 16, 16, 2, 2) and the three ck::Sequence
// arguments are positional tuning parameters of DeviceElementwiseImpl whose meaning is not
// visible in this file -- presumably tensor rank, block size, per-block / per-thread tile
// sizes, thread-cluster order and per-tensor vector widths. Confirm against
// device_elementwise_dynamic_vector_dims_impl.hpp before changing any of them.
using DeviceElementwiseFastGeluInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
ck::Tuple<CDataType>,
FastGelu,
1,
64,
16,
16,
2,
2,
ck::Sequence<1, 0>,
ck::Sequence<1>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorC, typename Functor>
void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0))>;
for(int m = 0; m < M; ++m)
{
auto Am = A(m);
ctype Cm = 0;
functor(Cm, Am);
C(m) = Cm;
}
}
// Example driver: runs the FastGelu elementwise device instance on a 1-D tensor of 1024
// elements and optionally compares the device result with a host-side reference.
//
// Command line:
//   (no arguments)  verification on, kernel timing off
//   arg1 arg2       arg1: verification (0=no, 1=yes); arg2: time kernel (0=no, 1=yes)
// Any other argument count prints the usage text and terminates with exit status 1.
//
// Returns 0 when verification passes or is skipped, 1 when verification fails.
// Throws std::runtime_error when the device instance rejects the argument.
int main(int argc, char* argv[])
{
    // Defaults used when the example is launched without arguments.
    bool do_verification = true;
    bool time_kernel     = false;

    if(argc == 3)
    {
        do_verification = std::stoi(argv[1]);
        time_kernel     = static_cast<bool>(std::stoi(argv[2]));
    }
    else if(argc != 1)
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: time kernel (0=no, 1=yes)\n");
        // A wrong argument count is a usage error: report failure so that scripts and CI
        // do not mistake it for a successful run.
        exit(1);
    }

    ck::index_t M = 1024;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor({len}, {stride});
    };

    // Contiguous (stride 1) host tensors for the input and the device result.
    Tensor<ADataType> a_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));

    a_m.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});

    DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize());
    DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());

    a_m_device_buf.ToDevice(a_m.mData.data());

    std::array<const void*, 1> input = {a_m_device_buf.GetDeviceBuffer()};
    std::array<void*, 1> output      = {c_m_device_buf.GetDeviceBuffer()};

    std::array<ck::index_t, 1> abc_lengths = {M};
    std::array<ck::index_t, 1> a_strides   = {1};
    std::array<ck::index_t, 1> c_strides   = {1};

    auto broadcastFastGelu = DeviceElementwiseFastGeluInstance{};
    auto argument          = broadcastFastGelu.MakeArgumentPointer(
        abc_lengths, {a_strides}, {c_strides}, input, output, FastGelu{});

    if(!broadcastFastGelu.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(
            "The runtime parameters seems not supported by the device instance, exiting!");
    }

    auto broadcastFastGelu_invoker_ptr = broadcastFastGelu.MakeInvokerPointer();
    float ave_time =
        broadcastFastGelu_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    bool pass = true;
    if(do_verification)
    {
        // Copy the device result back and compare it with the host reference.
        c_m_device_buf.FromDevice(c_m.mData.data());
        Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));

        host_elementwise1D<Tensor<ADataType>, Tensor<CDataType>, FastGelu>(
            host_c_m, a_m, M, FastGelu{});

        pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3);
    }

    return pass ? 0 : 1;
}

View File

@@ -0,0 +1,129 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using ADataType = F16;
using CDataType = F16;
using Tanh = ck::tensor_operation::element_wise::TanH;
// Device-side elementwise instance used by this example: a single ADataType input tensor, a
// single CDataType output tensor, and Tanh as the elementwise operation.
// NOTE(review): the numeric arguments (1, 64, 16, 16, 2, 2) and the three ck::Sequence
// arguments are positional tuning parameters of DeviceElementwiseImpl whose meaning is not
// visible in this file -- presumably tensor rank, block size, per-block / per-thread tile
// sizes, thread-cluster order and per-tensor vector widths. Confirm against
// device_elementwise_dynamic_vector_dims_impl.hpp before changing any of them.
using DeviceElementwiseTanhInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
ck::Tuple<CDataType>,
Tanh,
1,
64,
16,
16,
2,
2,
ck::Sequence<1, 0>,
ck::Sequence<1>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorC, typename Functor>
void host_elementwise1D(HostTensorC& C, const HostTensorA& A, int M, Functor functor)
{
using ctype = ck::remove_reference_t<decltype(C(0))>;
for(int m = 0; m < M; ++m)
{
auto Am = A(m);
ctype Cm = 0;
functor(Cm, Am);
C(m) = Cm;
}
}
// Example driver: runs the Tanh elementwise device instance on a 1-D tensor of 1024
// elements and optionally compares the device result with a host-side reference.
//
// Command line:
//   (no arguments)  verification on, kernel timing off
//   arg1 arg2       arg1: verification (0=no, 1=yes); arg2: time kernel (0=no, 1=yes)
// Any other argument count prints the usage text and terminates with exit status 1.
//
// Returns 0 when verification passes or is skipped, 1 when verification fails.
// Throws std::runtime_error when the device instance rejects the argument.
int main(int argc, char* argv[])
{
    // Defaults used when the example is launched without arguments.
    bool do_verification = true;
    bool time_kernel     = false;

    if(argc == 3)
    {
        do_verification = std::stoi(argv[1]);
        time_kernel     = static_cast<bool>(std::stoi(argv[2]));
    }
    else if(argc != 1)
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: time kernel (0=no, 1=yes)\n");
        // A wrong argument count is a usage error: report failure so that scripts and CI
        // do not mistake it for a successful run.
        exit(1);
    }

    ck::index_t M = 1024;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor({len}, {stride});
    };

    // Contiguous (stride 1) host tensors for the input and the device result.
    Tensor<ADataType> a_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));

    a_m.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});

    DeviceMem a_m_device_buf(sizeof(ADataType) * a_m.mDesc.GetElementSpaceSize());
    DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());

    a_m_device_buf.ToDevice(a_m.mData.data());

    std::array<const void*, 1> input = {a_m_device_buf.GetDeviceBuffer()};
    std::array<void*, 1> output      = {c_m_device_buf.GetDeviceBuffer()};

    std::array<ck::index_t, 1> abc_lengths = {M};
    std::array<ck::index_t, 1> a_strides   = {1};
    std::array<ck::index_t, 1> c_strides   = {1};

    auto broadcastTanh = DeviceElementwiseTanhInstance{};
    auto argument      = broadcastTanh.MakeArgumentPointer(
        abc_lengths, {a_strides}, {c_strides}, input, output, Tanh{});

    if(!broadcastTanh.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(
            "The runtime parameters seems not supported by the device instance, exiting!");
    }

    auto broadcastTanh_invoker_ptr = broadcastTanh.MakeInvokerPointer();
    float ave_time =
        broadcastTanh_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    bool pass = true;
    if(do_verification)
    {
        // Copy the device result back and compare it with the host reference.
        c_m_device_buf.FromDevice(c_m.mData.data());
        Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));

        host_elementwise1D<Tensor<ADataType>, Tensor<CDataType>, Tanh>(host_c_m, a_m, M, Tanh{});

        pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 4e-3, 4e-3);
    }

    return pass ? 0 : 1;
}

View File

@@ -43,29 +43,29 @@ using DeviceConvBwdWeightInstance =
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
64, // K0PerBlock
16, // K1
16, // MPerXdl
16, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format off

View File

@@ -41,29 +41,29 @@ using DeviceConvBwdWeightInstance =
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
64, // K0PerBlock
16, // K1
16, // MPerXdl
16, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<1, 64, 1, 2>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl
template <ck::index_t NDimSpatial>

View File

@@ -158,8 +158,8 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
int main()
{
// temp disable on gfx11
if(ck::is_gfx11_supported())
// temp disable on gfx11 & gfx12
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return 0;
}

View File

@@ -74,7 +74,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -69,14 +69,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
CDEElementOp,
GemmDefault,
256, // BlockSize
256, // MPerBlock
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
4, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder

View File

@@ -66,15 +66,15 @@ using DeviceBatchedGemmV2Instance =
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256, Scale_Block_N, Scale_Block_K,
16, 64,
32, 64,
KPerBlock, 8, 32,
16, 16,
1, 1,
1, 2,
S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 8,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
// clang-format on

View File

@@ -65,7 +65,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -523,7 +523,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 1;
}
@@ -536,7 +536,7 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[])
problem_size.M = 128 * (dis(gen) + 1);
problem_size.N = 128 * (dis(gen) + 1);
problem_size.K = 256 * (dis(gen) + 2);
problem_size.K = 256 * (dis(gen) + 4);
problem_size.batch_count = 2;

View File

@@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
@@ -59,7 +59,7 @@ using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
@@ -81,7 +81,7 @@ using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
@@ -103,7 +103,7 @@ using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
// Fp64 instances.
@@ -126,7 +126,7 @@ using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
@@ -148,7 +148,7 @@ using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
@@ -170,7 +170,7 @@ using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,

View File

@@ -370,3 +370,84 @@ inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvPa
throw std::runtime_error("unsuppored # dim spatial");
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    // Compile-time lookup of the "rtol" tolerance for the given element type.
    // The checks run top to bottom and the first matching type wins; any type
    // not listed falls through to the 1e-3 default at the end.
    using T = DataType;

    if constexpr(std::is_same_v<T, float>)
        return 1e-3;
    else if constexpr(std::is_same_v<T, double>)
        return 1e-6;
    else if constexpr(std::is_same_v<T, ck::half_t>)
        return 1e-3;
    else if constexpr(std::is_same_v<T, ck::bhalf_t>)
        return 5e-2;
    else if constexpr(std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, ck::f8_t>)
        return 1e-1; // for f8: 240 and 224 are acceptable
    else if constexpr(std::is_same_v<T, ck::bf8_t>)
        return 1.5e-1; // 57344 and 49152 are acceptable
    else
        return 1e-3;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    // Compile-time lookup of the "atol" tolerance for the given element type.
    // The checks run top to bottom and the first matching type wins; any type
    // not listed falls through to the 1e-3 default at the end.
    using T = DataType;

    if constexpr(std::is_same_v<T, float>)
        return 1e-3;
    else if constexpr(std::is_same_v<T, double>)
        return 1e-6;
    else if constexpr(std::is_same_v<T, ck::half_t>)
        return 1e-3;
    else if constexpr(std::is_same_v<T, ck::bhalf_t>)
        return 5e-2;
    else if constexpr(std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>)
        return 1e-1;
    else if constexpr(std::is_same_v<T, ck::f8_t>)
        return 16.1; // 240 and 224 are acceptable
    else if constexpr(std::is_same_v<T, ck::bf8_t>)
        return 8192.1; // 57344 and 49152 are acceptable
    else
        return 1e-3;
}

View File

@@ -25,7 +25,7 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -32,6 +32,8 @@ using BiasLayout = typename LayoutSettingSelector<NDimSpatial>::BiasLayout;
template <ck::index_t NDimSpatial>
using ResidualLayout = typename LayoutSettingSelector<NDimSpatial>::ResidualLayout;
static constexpr auto KPerBlock = sizeof(InKernelDataType) == 1 ? 64 : 32;
// instance for double rate mfma on gfx950 (vs gfx942)
template <ck::index_t NDimSpatial>
using DeviceConvFwdInstance2 =
@@ -105,7 +107,7 @@ using DeviceConvFwdInstance =
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
KPerBlock, // KPerBlock
4, // AK1
4, // BK1
16, // MPerXdl
@@ -333,11 +335,17 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
#ifdef BUILD_INT4_EXAMPLE
const Tensor<OutUserDataType> out_device_converted(out_device);
return ck::utils::check_err(
out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
return ck::utils::check_err(out_device_converted,
out_host,
"Error: incorrect results!",
get_rtol<OutUserDataType>(),
get_atol<OutUserDataType>());
#else
return ck::utils::check_err(
out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
return ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutUserDataType>(),
get_atol<OutUserDataType>());
#endif
}

View File

@@ -23,14 +23,14 @@ using DeviceConvFwdInstance =
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
16, // KPerBlock
128, // NPerBlock
64, // KPerBlock
4, // AK1
4, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -104,7 +104,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
#endif
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
@@ -128,7 +127,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvFwdInstance<NDimSpatial>{};
auto invoker = conv.MakeInvoker();
@@ -151,7 +149,6 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
InElementOp{},
WeiElementOp{},
OutElementOp{});
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(

View File

@@ -87,11 +87,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
4, // AK1
4, // BK1
1, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -114,7 +114,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
1,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
@@ -138,7 +138,7 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -7,6 +7,10 @@ add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_f
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
# When GPU_TARGETS contains a gfx125x architecture, build the self- and
# cross-attention WMMA fp16 examples with USE_GFX125_CONFIG=1.
# Each target is checked with if(TARGET ...) first, because
# target_compile_definitions() is a configure-time error when the named target
# does not exist.
# NOTE(review): assumes add_example_executable can skip creating a target under
# some configurations — confirm against its definition.
if(GPU_TARGETS MATCHES "gfx125")
    foreach(attention_example IN ITEMS
            example_self_attention_forward_wmma_fp16
            example_cross_attention_forward_wmma_fp16)
        if(TARGET ${attention_example})
            target_compile_definitions(${attention_example} PRIVATE USE_GFX125_CONFIG=1)
        endif()
    endforeach()
endif()
add_custom_target(example_gemm_scale_softmax_gemm)

View File

@@ -33,10 +33,10 @@ using DeviceGemmV2Instance =
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 256, 64,
128, 128, 64,
8, 8,
16, 16,
4, 4,
4, 2,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,

View File

@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
// clang-format on
#include "run_splitK_gemm_example.inc"

View File

@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_splitK_gemm_example.inc"

View File

@@ -54,14 +54,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 2>;
// clang-format on
#include "run_splitK_gemm_example.inc"
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -52,7 +52,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 32, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
// clang-format on
#include "run_splitK_gemm_example.inc"

View File

@@ -64,7 +64,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else
@@ -85,7 +85,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlS
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<BiasDataType>, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
< NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<BiasDataType>, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_grouped_conv_bwd_data_bias_relu_example.inc"

View File

@@ -31,7 +31,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>;
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>;
// clang-format on
#include "run_grouped_conv_bwd_data_example.inc"

View File

@@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 64, 16, 16, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
#include "run_grouped_conv_bwd_data_example.inc"

View File

@@ -30,7 +30,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>;
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 64, 16, 16, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>;
// clang-format on
#include "run_grouped_conv_bwd_data_example.inc"

View File

@@ -8,6 +8,5 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)

View File

@@ -77,11 +77,11 @@ using DeviceBatchedGemmGemmInstance =
4, // AK1
4, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
8, // NXdlPerWave
8, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -106,13 +106,13 @@ using DeviceBatchedGemmGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
2>; // CShuffleBlockTransferScalarPerVector_NPerBlock
#include "run_grouped_conv_conv_fwd_example.inc"
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
return 0;
}

View File

@@ -106,7 +106,7 @@ using DeviceBatchedGemmGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
#include "run_grouped_conv_conv_fwd_example.inc"

View File

@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -71,7 +71,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -55,7 +55,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -54,7 +54,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -48,7 +48,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -49,7 +49,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -52,7 +52,7 @@ using DeviceGroupedConvNDFwdInstance =
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
128, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma

View File

@@ -53,7 +53,7 @@ using DeviceGroupedConvNDFwdInstance =
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
64, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl

View File

@@ -26,9 +26,9 @@ using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<D
int main(int argc, char* argv[])
{
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
if(ck::is_gfx11_supported() || ck::is_gfx120_supported())
{
std::cout << "FP32 are not supported on gfx11 and gfx12" << std::endl;
std::cout << "FP32 are not supported on gfx11 and gfx120x" << std::endl;
return 0;
}

View File

@@ -44,6 +44,7 @@ static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto KPerBlock = sizeof(InDataType) == 1 ? 64 : 32;
#ifdef EXAMPLE_USE_WMMA
template <typename DataType,
@@ -68,32 +69,32 @@ using DeviceGroupedConvNDMultiABFwdInstance =
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MWmmaPerWave
4, // NWmmaPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
sizeof(DataType) == 1 ? 64 : 32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MWmmaPerWave
4, // NWmmaPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
@@ -123,33 +124,33 @@ using DeviceGroupedConvNDMultiABFwdInstance =
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
sizeof(DataType) == 1 ? 64 : 32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,

View File

@@ -20,7 +20,8 @@ add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_bl
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp)
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic)
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic gfx1250)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)

View File

@@ -49,8 +49,6 @@ using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
static constexpr int KPack = 8;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
@@ -58,6 +56,13 @@ using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// K extent of one block tile; used in the DeviceOpInstance template argument list.
static constexpr int KPerBlock = 64;
// K-pack width, selected at compile time: 16 when CK_USE_GFX1250 is defined, 8 otherwise.
// It appears twice right after KPerBlock in DeviceOpInstance
// (presumably the AK1/BK1 slots -- TODO confirm against the device op signature).
#if defined(CK_USE_GFX1250)
static constexpr int KPack = 16;
#else
static constexpr int KPack = 8;
#endif
// Number of KPack-wide chunks in KPerBlock (integer division); used as the first
// length of the S<K0, 128 / K0, 1> sequences in DeviceOpInstance.
static constexpr auto K0 = KPerBlock / KPack;
// clang-format off
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle<
@@ -65,12 +70,12 @@ using DeviceOpInstance =
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
128,
32, 128, 128,
8, 8,
32, 128, KPerBlock,
KPack, KPack,
16, 16,
2, 2,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<K0, 128 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<K0, 128 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,

View File

@@ -50,8 +50,6 @@ using D1Layout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
static constexpr int KPack = 16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
@@ -63,6 +61,13 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
// K extent of one block tile; used in the DeviceOpInstance template argument list.
static constexpr int KPerBlock = 128;
// K-pack width, selected at compile time: 32 when CK_USE_GFX1250 is defined, 16 otherwise.
// It appears twice right after KPerBlock in DeviceOpInstance
// (presumably the AK1/BK1 slots -- TODO confirm against the device op signature).
#if defined(CK_USE_GFX1250)
static constexpr int KPack = 32;
#else
static constexpr int KPack = 16;
#endif
// Number of KPack-wide chunks in KPerBlock (integer division); used as the first
// length of the S<K0, 256 / K0, 1> sequences in DeviceOpInstance.
static constexpr auto K0 = KPerBlock / KPack;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
@@ -71,13 +76,13 @@ using DeviceOpInstance =
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
128, 128, 128,
16, 16,
128, 128, KPerBlock,
KPack, KPack,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
1, 1,
S<1, 32, 1, 8>, S<8>,

View File

@@ -34,10 +34,9 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = F8;
static constexpr int KPack = 16;
using ComputeType = F8;
using A0DataType = F8;
using B0DataType = F8;
using ComputeType = F8;
using AccDataType = F32;
using CShuffleDataType = F32;
@@ -60,7 +59,13 @@ using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// K extent of one block tile; used in the DeviceOpInstance template argument list.
static constexpr int KPerBlock = 256;
// K-pack width, selected at compile time: 32 when CK_USE_GFX1250 is defined, 16 otherwise.
// It appears twice right after KPerBlock in DeviceOpInstance
// (presumably the AK1/BK1 slots -- TODO confirm against the device op signature).
#if defined(CK_USE_GFX1250)
static constexpr int KPack = 32;
#else
static constexpr int KPack = 16;
#endif
// Number of KPack-wide chunks in KPerBlock (integer division); used as the first
// length of the S<K0, 256 / K0, 1> sequences in DeviceOpInstance.
static constexpr auto K0 = KPerBlock / KPack;
// clang-format off
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle<
@@ -68,12 +73,12 @@ using DeviceOpInstance =
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256,
32, 128, 256,
16, 16,
32, 128, KPerBlock,
KPack, KPack,
16, 16,
2, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<K0, 256 / K0, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,

View File

@@ -74,7 +74,14 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// Block-GEMM pipeline version handed to DeviceOpInstance: v1 when the build
// defines CK_USE_WMMA without CK_USE_GFX1250, v3 in every other configuration.
#if !defined(CK_USE_WMMA) || defined(CK_USE_GFX1250)
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v3;
#else
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
#endif
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
@@ -88,7 +95,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, FP8>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -114,7 +114,14 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// Block-GEMM pipeline version handed to DeviceOpInstance: v1 when the build
// defines CK_USE_WMMA without CK_USE_GFX1250, v3 in every other configuration.
#if !defined(CK_USE_WMMA) || defined(CK_USE_GFX1250)
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v3;
#else
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
#endif
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
@@ -135,7 +142,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<4, 4, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, I8>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -6,9 +6,15 @@ add_custom_target(example_gemm_mx)
# Single-precision-pair MX GEMM examples. Each executable is registered with
# add_example_executable and then attached to the example_gemm_mx aggregate
# target through add_example_dependencies.
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
add_example_executable(example_gemm_mx_fp8_v1 gemm_mx_fp8_v1.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_v1)
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
add_example_executable(example_gemm_mx_fp8_bpreshuffle gemm_mx_fp8_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bpreshuffle)
# TODO: Fix RRR
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
@@ -63,9 +69,81 @@ example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_M
# Extra compile options for the FP8/BF8 MX GEMM examples. The list is rebuilt
# from an unset variable. The "SHELL:" prefix makes CMake keep "-mllvm <arg>"
# together as one option group instead of splitting or de-duplicating it.
set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp8_v1 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp8_bpreshuffle PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
# Extra compile options for the FP6/BF6 MX GEMM examples.
# NOTE(review): -mavx512f is an x86 host ISA flag -- confirm it is intended for
# these targets and that every supported host compiler accepts it.
set(FP6_MXGEMM_OPTIONS)
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS})
# add_gemm_mix_prec(<A_NAME> <B_NAME> <A_TYPE> <B_TYPE>)
#
# Registers the two mixed-precision MX GEMM examples for one A/B combination:
#   example_gemm_mx_<A_NAME>_<B_NAME>              from gemm_mx_fp4.cpp
#   example_gemm_mx_<A_NAME>_<B_NAME>_bpreshuffle  from gemm_mx_fp4_bpreshuffle.cpp
#
# Arguments:
#   A_NAME, B_NAME  tags spliced into the target names (e.g. fp4, fp8)
#   A_TYPE, B_TYPE  values of the A_DATATYPE / B_DATATYPE compile definitions
#
# Every example is attached to the example_gemm_mx aggregate target and gets
# FP4_MXGEMM_OPTIONS as private compile options.
#
# Example: add_gemm_mix_prec(fp4 fp8 F4 F8)
function(add_gemm_mix_prec A_NAME B_NAME A_TYPE B_TYPE)
  foreach(mixprec_variant IN ITEMS plain bpreshuffle)
    # "plain" is only a loop key: that variant carries no name/file suffix.
    if("${mixprec_variant}" STREQUAL "plain")
      set(mixprec_suffix "")
    else()
      set(mixprec_suffix "_${mixprec_variant}")
    endif()
    set(mixprec_target "example_gemm_mx_${A_NAME}_${B_NAME}${mixprec_suffix}")

    add_example_executable(${mixprec_target} gemm_mx_fp4${mixprec_suffix}.cpp)
    add_example_dependencies(example_gemm_mx ${mixprec_target})
    example_compile_options(${mixprec_target} PRIVATE ${FP4_MXGEMM_OPTIONS})

    # target_compile_definitions() is a hard configure error on a target that
    # does not exist. add_example_executable presumably may skip creating the
    # target (dtype / GPU-target filtering -- TODO confirm), so only decorate
    # targets that were really created.
    if(TARGET ${mixprec_target})
      target_compile_definitions(${mixprec_target} PRIVATE
        A_DATATYPE=${A_TYPE}
        B_DATATYPE=${B_TYPE})
    endif()
  endforeach()
endfunction()
# add_moe_mix_prec(<A_NAME> <B_NAME> <A_TYPE> <B_TYPE>)
#
# Registers the six mixed-precision MX MoE GEMM examples for one A/B
# combination: stages gemm1 and gemm2, each in three variants.
#   _bns          mx moe B no-shuffling + scale shuffling
#   (no suffix)   mx moe B no-shuffling + scale shuffling (async loads)
#   _bpreshuffle  mx moe B shuffling + scale shuffling (async loads)
#
# Target:  example_moe_<stage>_xdl_mx_<A_NAME>_<B_NAME><suffix>
# Source:  moe_<stage>_xdl_mx_fp4<suffix>.cpp
#
# Arguments:
#   A_NAME, B_NAME  tags spliced into the target names (e.g. fp4, fp8)
#   A_TYPE, B_TYPE  values of the A_DATATYPE / B_DATATYPE compile definitions
#
# Every example is attached to the example_gemm_mx aggregate target and gets
# FP4_MXGEMM_OPTIONS as private compile options.
#
# Example: add_moe_mix_prec(fp8 fp4 F8 F4)
function(add_moe_mix_prec A_NAME B_NAME A_TYPE B_TYPE)
  foreach(mixprec_variant IN ITEMS bns plain bpreshuffle)
    # "plain" is only a loop key: that variant carries no name/file suffix.
    if("${mixprec_variant}" STREQUAL "plain")
      set(mixprec_suffix "")
    else()
      set(mixprec_suffix "_${mixprec_variant}")
    endif()

    foreach(mixprec_stage IN ITEMS gemm1 gemm2)
      set(mixprec_target
          "example_moe_${mixprec_stage}_xdl_mx_${A_NAME}_${B_NAME}${mixprec_suffix}")

      add_example_executable(${mixprec_target}
                             moe_${mixprec_stage}_xdl_mx_fp4${mixprec_suffix}.cpp)
      add_example_dependencies(example_gemm_mx ${mixprec_target})
      example_compile_options(${mixprec_target} PRIVATE ${FP4_MXGEMM_OPTIONS})

      # target_compile_definitions() is a hard configure error on a target that
      # does not exist. add_example_executable presumably may skip creating the
      # target (dtype / GPU-target filtering -- TODO confirm), so only decorate
      # targets that were really created.
      if(TARGET ${mixprec_target})
        target_compile_definitions(${mixprec_target} PRIVATE
          A_DATATYPE=${A_TYPE}
          B_DATATYPE=${B_TYPE})
      endif()
    endforeach()
  endforeach()
endfunction()
# MX mixed-precision examples: registered only when GPU_TARGETS matches
# "gfx125" (regex/substring match, so any gfx125x entry enables them).
# Arguments are <A name> <B name> <A type> <B type>.
if(GPU_TARGETS MATCHES "gfx125")
add_gemm_mix_prec(fp4 fp8 F4 F8)
add_gemm_mix_prec(fp8 fp4 F8 F4)
add_moe_mix_prec(fp4 fp8 F4 F8)
add_moe_mix_prec(fp8 fp4 F8 F4)
add_moe_mix_prec(fp8 fp8 F8 F8)
endif()

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::bf6x16_pk_t;
using BDataType = ck::bf6x16_pk_t;

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::bf8_t;
using BDataType = ck::bf8_t;
@@ -44,14 +45,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
128, // BlockSize: Thread block size
128, // MPerBlock
32, // NPerBlock
64, // MPerBlock
64, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
2, // MXdlPerWave
2, // NXdlPerWave
S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder

View File

@@ -9,7 +9,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
@@ -27,9 +26,10 @@ using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -116,7 +116,7 @@ bool parse_cmd_args(int argc,
}
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
void preShuffleScaleBuffer_gfx950(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
@@ -126,8 +126,9 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// On gfx950, WarpSize=64:
// The 4 16x128 building blocks will be packed into 1 32x256
// The 8 16x16x128 mfma will be packed into 1 32x32x256
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
@@ -163,13 +164,74 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i
}
}
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
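/**
* Pre-shuffle scale buffer for gfx1250 16x16x128 wmma scale instruction
*
* @tparam ScaleType Scale data type (must be 1 byte wide)
* @tparam ScaleBlockSize Number of K elements that share one scale (only 32 is supported)
* @tparam KStride true if K is the contiguous dimension of the source scale buffer
*                 (read as src[mn * K + k]); false if MN is contiguous (src[k * MN + mn])
*/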
template <typename ScaleType, ck::index_t ScaleBlockSize, bool KStride>
void preShuffleScaleBuffer_gfx1250(const ScaleType* src,
ScaleType* dst,
ck::index_t MN,
ck::index_t K)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1,
"wrong! only support 8-bit scale with ScaleBlockSize=32");
constexpr ck::index_t MPerXdlops = 16;
// constexpr ck::index_t NPerXdlops = 16;
constexpr ck::index_t KPerXdlops = 128;
int MNPack = 2; // 2 sets of scales in M/N direction
int KPack = 1; // 1 set of scales in K direction
int MNStep = MPerXdlops;
int KStep = KPerXdlops / ScaleBlockSize; // scales per thread
int K0 = K / KPack / KStep; // KRepeat - how many KStep blocks
// On gfx1250, WarpSize=32:
// -- The 2 16x128 building blocks will be packed into 1 32x128
// -- The 4 16x16x128 wmma will be packed into 1 32x32x128
// unfold the MN32xK(128/32) scale buffer
// 4 16 1 2
// To KStep -> MNStep -> KPack -> MNPack
// or ???
// 2 16 1 4
// MNPack -> MNStep -> KPack -> KStep
for(int mn = 0; mn < MN; ++mn)
{
int iMNRepeat = mn / (MNStep * MNPack); // i MNRepeat (MN block id)
int tempmn = mn % (MNStep * MNPack); // position in MN block
for(int k = 0; k < K; ++k)
{
int iKRepeat = k / (KStep * KPack); // i KRepeat
int tempk = k % (KStep * KPack); // position in KStep block
int outputIndex = (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) +
(iKRepeat * KStep * KPack) * (MNStep * MNPack) +
tempmn * (KStep * KPack) + tempk;
if constexpr(KStride)
dst[outputIndex] = src[mn * K + k];
else
dst[outputIndex] = src[k * MN + mn];
}
}
}
template <typename T>
void preShuffleBuffer(const T* src, T* dst, int N, int K, int NXdl)
{
const int KPack = 16;
const int NLane = NXdl;
const int KLane = ck::get_warp_size() / NLane;
const int K_pk = K / ck::packed_size_v<T>;
const int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
@@ -352,7 +414,6 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorDistr(
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
@@ -369,12 +430,34 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
}
}
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
if(ck::get_warp_size() == 64)
{
preShuffleScaleBuffer_gfx950<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer_gfx950<ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
}
else if(ck::get_warp_size() == 32)
{
preShuffleScaleBuffer_gfx1250<ck::e8m0_bexp_t, ScaleBlockSize, ck::is_same_v<ALayout, Row>>(
a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer_gfx1250<ck::e8m0_bexp_t,
ScaleBlockSize,
ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
}
else
{
throw std::runtime_error("wrong! Scale pre-shuffle unsupported warp size");
}
if constexpr(BPreShuffle)
{
int NPerXdl = 16; // Fixed 16
@@ -459,7 +542,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(config.verbosity > 0)
{
std::cout << "Done." << std::endl;
std::cout << "\nDone." << std::endl;
std::cout << "Computing GEMM on host..." << std::endl;
}

View File

@@ -2,9 +2,20 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using F4 = ck::f4x2_pk_t;
using F8 = ck::f8_t;
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
#if defined(A_DATATYPE)
using ADataType = A_DATATYPE;
#else
using ADataType = F4;
#endif
#if defined(B_DATATYPE)
using BDataType = B_DATATYPE;
#else
using BDataType = F4;
#endif
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
@@ -21,7 +32,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t DataPackedSize =
ck::packed_size_v<ADataType>; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2

View File

@@ -2,9 +2,21 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
using F4 = ck::f4x2_pk_t;
using F8 = ck::f8_t;
#if defined(A_DATATYPE)
using ADataType = A_DATATYPE;
#else
using ADataType = F4;
#endif
#if defined(B_DATATYPE)
using BDataType = B_DATATYPE;
#else
using BDataType = F4;
#endif
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
@@ -21,7 +33,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t DataPackedSize =
ck::packed_size_v<ADataType>; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
@@ -30,7 +43,7 @@ constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathmatically, all numbers are represented as f4x2.
// Mathematically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
@@ -47,24 +60,24 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
512, // NPerBlock
128, // BlockSize: Thread block size
64, // MPerBlock
64, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // MXdlPerWave
2, // NXdlPerWave
S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
@@ -72,9 +85,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlockW
2, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::f6x16_pk_t;
using BDataType = ck::f6x16_pk_t;

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::bf8_t;

View File

@@ -0,0 +1,101 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
// Operand element types: A and B are both 8-bit floats (ck::f8_t).
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
// Scale types: XDataType is the scale element type; XPackedDataType is the
// 32-bit type handed to the device op as the A/B scale data type.
// NOTE(review): presumably four 8-bit e8m0 scales are packed per int32 -- confirm.
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
// Output is half precision; AccDataType is passed as the GEMM accumulator type
// and the C-shuffle stage uses the same type as the output.
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
// Layout tags for A, B and C.
// NOTE(review): Row and MFMA are declared outside this file (presumably in
// gemm_mx_common.hpp); MFMA presumably selects a pre-shuffled B layout -- confirm.
using ALayout = Row;
using BLayout = MFMA;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
// Passed to the device op as its KPerBlock argument.
constexpr ck::index_t KPerBlock = 256;
// Default GEMM specialization, Intrawave scheduler and block-GEMM pipeline
// version v3 are forwarded to the device op below.
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// Device-op instantiation used by this example: DeviceGemmMX_Xdl_CShuffleV3 with
// a 256-thread block, a 128 (M) x 128 (N) x KPerBlock (K) block tile assembled
// from 16 x 16 XDL tiles (4 x 4 per wave), and the scheduler / pipeline version
// chosen by BlkGemmPSched and BlkGemmPVer. The trailing comment on each argument
// names the template parameter it binds to.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
// Entry point: hands the command line to run_mx_gemm_example, instantiated with
// the types configured in this file, and turns its result into the process exit
// status (0 when the call evaluates to true, -1 otherwise).
int main(int argc, char* argv[])
{
    if(run_mx_gemm_example<DeviceOpInstance,
                           ADataType,
                           BDataType,
                           XDataType,
                           XPackedDataType,
                           CDataType,
                           ALayout,
                           BLayout,
                           CLayout,
                           AElementOp,
                           BElementOp,
                           CElementOp,
                           AccDataType,
                           CShuffleDataType,
                           ScaleBlockSize>(argc, argv))
    {
        return 0;
    }

    return -1;
}

View File

@@ -0,0 +1,101 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
// Operand element types: A and B are both 8-bit floats (ck::f8_t).
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
// Scale types: XDataType is the scale element type; XPackedDataType is the
// 32-bit type handed to the device op as the A/B scale data type.
// NOTE(review): presumably four 8-bit e8m0 scales are packed per int32 -- confirm.
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
// Output is half precision; AccDataType is passed as the GEMM accumulator type
// and the C-shuffle stage uses the same type as the output.
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
// Layout tags for A, B and C (Row / Col are declared outside this file).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
// Passed to the device op as its KPerBlock argument.
constexpr ck::index_t KPerBlock = 256;
// Default GEMM specialization, Intrawave scheduler and block-GEMM pipeline
// version v1 are forwarded to the device op below.
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
// Device-op instantiation used by this example: DeviceGemmMX_Xdl_CShuffleV3 with
// a 256-thread block, a 64 (M) x 128 (N) x KPerBlock (K) block tile assembled
// from 16 x 16 XDL tiles (2 x 4 per wave), and the scheduler / pipeline version
// chosen by BlkGemmPSched and BlkGemmPVer. The trailing comment on each argument
// names the template parameter it binds to.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
64, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
// Entry point: hands the command line to run_mx_gemm_example, instantiated with
// the types configured in this file, and turns its result into the process exit
// status (0 when the call evaluates to true, -1 otherwise).
int main(int argc, char* argv[])
{
    if(run_mx_gemm_example<DeviceOpInstance,
                           ADataType,
                           BDataType,
                           XDataType,
                           XPackedDataType,
                           CDataType,
                           ALayout,
                           BLayout,
                           CLayout,
                           AElementOp,
                           BElementOp,
                           CElementOp,
                           AccDataType,
                           CShuffleDataType,
                           ScaleBlockSize>(argc, argv))
    {
        return 0;
    }

    return -1;
}

View File

@@ -1,47 +1,29 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "gemm_mx_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
using A0DataType = F4;
#if defined(A_DATATYPE)
using A0DataType = A_DATATYPE;
#else
using A0DataType = F4;
#endif
#if defined(B_DATATYPE)
using B0DataType = B_DATATYPE;
#else
using B0DataType = F4;
#endif
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
@@ -89,67 +71,14 @@ struct MulABScaleExpertWeight
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B scale pre-shuffle (host side).
//
// Re-orders the logical MN x K scale matrix into tiles of 32 (MN) x 8 (K) scales.
// From slowest to fastest varying, the destination order is
//   MNRepeat -> KRepeat -> XdlKThread(4) -> XdlMNThread(16) -> KXdlPack(2) -> MNXdlPack(2)
// KLast selects how src is addressed: true reads src[n * K + k], false reads
// src[k * MN + n].
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
    const int MNXdlPack   = 2;
    const int KXdlPack    = 2;
    const int XdlMNThread = 16;
    const int XdlKThread  = 64 / XdlMNThread;
    const int K0          = K / KXdlPack / XdlKThread; // KRepeat

    // The 4 16x128 building blocks will be packed into 1 32x256 for F4
    // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
    const int mnTile   = XdlMNThread * MNXdlPack;              // rows per tile
    const int kTile    = XdlKThread * KXdlPack;                // columns per tile
    const int packSize = MNXdlPack * KXdlPack;                 // scales per (n1, k1) pair
    const int tileSize = packSize * XdlMNThread * XdlKThread;  // scales per tile

    for(int row = 0; row < MN; ++row)
    {
        const int rowInTile = row % mnTile;
        const int mnRepeat  = row / mnTile;            // i MNRepeat
        const int mnThread  = rowInTile % XdlMNThread; // i XdlMNThread
        const int mnPack    = rowInTile / XdlMNThread; // i MNXdlPack

        for(int col = 0; col < K; ++col)
        {
            const int colInTile = col % kTile;
            const int kRepeat   = col / kTile;            // i KRepeat
            const int kThread   = colInTile % XdlKThread; // i XdlKThread
            const int kPack     = colInTile / XdlKThread; // i KXdlPack

            const int outputIndex = mnRepeat * tileSize * K0 + kRepeat * tileSize +
                                    kThread * packSize * XdlMNThread + mnThread * packSize +
                                    kPack * MNXdlPack + mnPack;

            if constexpr(KLast)
            {
                dst[outputIndex] = src[row * K + col];
            }
            else
            {
                dst[outputIndex] = src[col * MN + row];
            }
        }
    }
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
@@ -157,6 +86,11 @@ static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t BlockSize = 256;
static constexpr bool MulRoutedWeight = true;
static constexpr ck::index_t ClusterLengths_BK0 =
ck::is_same_v<B0DataType, F4> && ck::is_same_v<A0DataType, F4> ? 8 : 4;
static constexpr ck::index_t ClusterLengths_N =
ck::is_same_v<B0DataType, F4> && ck::is_same_v<A0DataType, F4> ? 32 : 64;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
@@ -166,10 +100,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MPerBlock, NPerBlock, KPerBlock,
16, 16,
16, 16,
4, 2,
2, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
S<ClusterLengths_BK0, ClusterLengths_N, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<4, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
@@ -373,6 +307,9 @@ int main(int argc, char* argv[])
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// a0_t_k.savetxt("a.txt", "float", 128);
// b0_e_n_k.savetxt("b.txt", "float", 128);
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
@@ -392,14 +329,30 @@ int main(int argc, char* argv[])
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
if(ck::is_gfx125_supported())
{
preShuffleScaleBuffer_gfx1250<XDataType, ScaleBlockSize, ck::is_same_v<A0Layout, Row>>(
a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer_gfx1250<XDataType, ScaleBlockSize, ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
}
else
{
preShuffleScaleBuffer_gfx950<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer_gfx950<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
}
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
@@ -452,9 +405,10 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
ck::is_gfx125_supported()))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
std::cout << "This kernel support gfx942, gfx950 and gfx125x only" << std::endl;
}
if(time_kernel)

Some files were not shown because too many files have changed in this diff Show More