Merge commit 'b0ee317d83b77741022997265d4125697e7f7804' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-12 20:11:58 +00:00
parent facbc883fa
commit 302aa809ea
65 changed files with 2301 additions and 232 deletions

View File

@@ -1,10 +1,8 @@
ARG BASE_DOCKER="rocm/pytorch:latest"
ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base"
FROM $BASE_DOCKER
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN groupadd -g 109 render && \
usermod -u 1001 jenkins && \
groupmod -g 1001 jenkins && \
RUN groupadd irc && \
pip install pandas zmq einops && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \

46
Jenkinsfile vendored
View File

@@ -149,7 +149,7 @@ def getDockerImage(Map conf=[:]){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){
else if ( (params.BUILD_GFX950 || params.RUN_CK_TILE_FMHA_TESTS) && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using special docker: ${image}"
}
@@ -186,11 +186,11 @@ def buildDocker(install_prefix){
dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . "
}
else if(params.RUN_AITER_TESTS){
image_name = "rocm/composable_kernel:ck_aiter"
image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
}
else if(params.RUN_PYTORCH_TESTS){
image_name = "rocm/composable_kernel:ck_pytorch"
image_name = "${env.CK_DOCKERHUB}:ck_pytorch"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
}
else{
@@ -716,7 +716,7 @@ def process_results(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
//use older image that has user jenkins
def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3"
def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3"
def prefixpath = "/opt/rocm"
// Jenkins is complaining about the render group
@@ -827,7 +827,7 @@ def run_aiter_tests(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
//use the latest pytorch image
def image = "rocm/composable_kernel:ck_aiter"
def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter"
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
def variant = env.STAGE_NAME
def retimage
@@ -885,7 +885,7 @@ def run_pytorch_tests(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
//use the latest pytorch-nightly image
def image = "rocm/composable_kernel:ck_pytorch"
def image = "${env.CK_DOCKERHUB}:ck_pytorch"
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
def variant = env.STAGE_NAME
def retimage
@@ -1207,6 +1207,18 @@ pipeline {
cleanWs()
}
}
stage("Run AITER Tests on gfx950")
{
when {
beforeAgent true
expression { params.RUN_AITER_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx950")}
steps{
run_aiter_tests()
cleanWs()
}
}
}
}
stage("Run Grouped Conv Large Case Tests")
@@ -1321,7 +1333,7 @@ pipeline {
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \
make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \
cd ../ &&
example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
}
@@ -1330,6 +1342,26 @@ pipeline {
cleanWs()
}
}
stage("Run CK_TILE_FMHA Tests on gfx950")
{
when {
beforeAgent true
expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx950") }
environment{
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0"
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \
make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \
cd ../ &&
example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Run TILE_ENGINE_GEMM Tests")

View File

@@ -27,3 +27,16 @@ add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_spli
add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp)
add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp)
add_custom_target(example_splitK_gemm_wmma)
add_example_executable(example_gemm_wmma_splitk_reduce_bf16 gemm_wmma_splitk_reduce_bf16.cpp)
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16)
add_example_executable(example_gemm_wmma_splitk_reduce_bf16A_i8B gemm_wmma_splitk_reduce_bf16A_i8B.cpp)
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16A_i8B)
add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_bf16 gemm_wmma_splitk_reduce_multi_d_bf16.cpp)
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_bf16)
add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_fp16 gemm_wmma_splitk_reduce_multi_d_fp16.cpp)
add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_fp16)

View File

@@ -99,3 +99,85 @@ bool parse_cmd_args(int argc,
return true;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ReduceDataType = ck::bhalf_t;
using D0DataType = ck::bhalf_t;
using DsDataType = ck::Tuple<>;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using D0Layout = CLayout;
using DsLayout = ck::Tuple<>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceWmmaGemmInstance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
ALayout, BLayout, DsLayout, CLayout,
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 128, 32,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
#include "run_gemm_wmma_splitk_reduce_example.inc"
int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
using ADataType = ck::bhalf_t;
using BDataType = int8_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ReduceDataType = float;
using D0DataType = ck::bhalf_t;
using DsDataType = ck::Tuple<>;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceWmmaGemmInstance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
ALayout, BLayout, DsLayout, CLayout,
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 128, 32,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
#include "run_gemm_wmma_splitk_reduce_example.inc"
int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ReduceDataType = float;
using D0DataType = ck::bhalf_t;
using DsDataType = ck::Tuple<D0DataType>;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using D0Layout = CLayout;
using DsLayout = ck::Tuple<D0Layout>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
ALayout, BLayout, DsLayout, CLayout,
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 128, 32,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); }

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ReduceDataType = float;
using D0DataType = ck::half_t;
using DsDataType = ck::Tuple<D0DataType>;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using D0Layout = CLayout;
using DsLayout = ck::Tuple<D0Layout>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1<
ALayout, BLayout, DsLayout, CLayout,
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 256, 64,
8, 8,
16, 16,
4, 4,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1, ReduceDataType>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); }

View File

@@ -3,88 +3,6 @@
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{

View File

@@ -0,0 +1,191 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ProblemType>
bool run_wmma_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "init method: " << config.init_method << std::endl;
std::cout << "KBatch: " << KBatch << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device GEMM
auto device_op = DeviceWmmaGemmInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 0>{}, // empty D tensors
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 0>{}, // empty D strides
StrideC,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
// Allocate workspace for split-K reduction if needed
size_t workspace_size = device_op.GetWorkSpaceSize(argument.get());
DeviceMem workspace_buf(workspace_size);
std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl;
if(workspace_size > 0)
{
argument->p_workspace_ = workspace_buf.GetDeviceBuffer();
std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl;
}
if(!device_op.IsSupportedArgument(argument.get()))
{
std::cout << "The runtime argument is not supported!" << std::endl;
std::cout << "Debug info:" << std::endl;
std::cout << " M=" << M << ", N=" << N << ", K=" << K << ", KBatch=" << KBatch
<< std::endl;
std::cout << " StrideA=" << StrideA << ", StrideB=" << StrideB << ", StrideC=" << StrideC
<< std::endl;
return false;
}
bool pass = true;
float ave_time = 0;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, false});
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E12 / ave_time;
float gb_per_sec = num_btype / 1.E9 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
return pass;
}
bool run_wmma_gemm_splitk_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config);
}

View File

@@ -0,0 +1,214 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ProblemSize>
bool run_wmma_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto StrideD0 = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{});
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(problem_size.M, problem_size.K, problem_size.StrideA, ALayout{}));
Tensor<BDataType> b_k_n(
f_host_tensor_descriptor(problem_size.K, problem_size.N, problem_size.StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, D0Layout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "init method: " << config.init_method << std::endl;
std::cout << "KBatch: " << KBatch << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
constexpr auto kNum_DTensors = DsDataType::Size();
const std::array<const void*, kNum_DTensors> p_ds = {d0_m_n_device_buf.GetDeviceBuffer()};
const std::array<ck::index_t, kNum_DTensors> d_strides = {problem_size.StrideC};
auto argument =
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
p_ds,
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
problem_size.M,
problem_size.N,
problem_size.K,
problem_size.StrideA,
problem_size.StrideB,
d_strides,
problem_size.StrideC,
problem_size.KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument.get()))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return false;
}
auto workspace_size = gemm.GetWorkSpaceSize(argument.get());
DeviceMem workspace_device_buf(workspace_size);
std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl;
std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl;
if(workspace_size > 0)
{
argument->p_workspace_ = workspace_device_buf.GetDeviceBuffer();
}
if(config.do_verification)
{
using ReferenceGemmInstanceMultiD = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstanceMultiD{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
c_m_n_host_result.ForEach(
[&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); });
}
std::cout << "init method: " << config.init_method << std::endl;
std::cout << "KBatch: " << problem_size.KBatch << std::endl;
float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * problem_size.M * problem_size.N * problem_size.K;
std::size_t num_btype = sizeof(ADataType) * problem_size.M * problem_size.K +
sizeof(BDataType) * problem_size.K * problem_size.N +
sizeof(CDataType) * problem_size.M * problem_size.N +
sizeof(D0DataType) * problem_size.M * problem_size.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
double rtol = get_rtol<CDataType>();
double atol = get_atol<CDataType>();
return ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", rtol, atol);
}
return true;
}
int run_gemm_splitk_multi_d_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config);
}

View File

@@ -26,7 +26,7 @@ endforeach()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
@@ -51,6 +51,15 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
)
# Reduce building time by disabling instances that are not currently used in the gtests
# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of
# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
endif()
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}

View File

@@ -181,15 +181,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}

View File

@@ -314,15 +314,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
GemmConfig::K_Warp_Tile / divisor / kABK0PerLane});
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}

View File

@@ -45,15 +45,15 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
if(ck_tile::is_gfx12_supported())
{
// TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase<gfx12>
constexpr int divisor = 2;
constexpr int kABK0PerLane = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
kABK0PerLane,
FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane});
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}

View File

@@ -129,5 +129,10 @@ inline bool is_gfx103_supported()
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_wmma_supported()
{
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}
} // namespace ck
#endif

View File

@@ -0,0 +1,562 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <type_traits>
#include <typeinfo>
#include <memory>
#include <array>
#include <stdexcept>
#include "ck/utility/common_header.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ReduceDataType = CDataType,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
ReduceDataType,
Tuple<>,
ReduceDataType,
AElementwiseOperation,
BElementwiseOperation,
PassThrough,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false,
false>;
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
const ::std::array<const void*, NumDTensor> p_ds_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
const ::std::array<index_t, NumDTensor> stride_ds_,
index_t StrideC_,
index_t KBatch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
::std::array<const void*, 0>{},
reinterpret_cast<ReduceDataType*>(p_c_grid_),
M_,
N_,
K_,
StrideA_,
StrideB_,
std::array<index_t, 0>{},
StrideC_,
KBatch_,
a_element_op_,
b_element_op_,
PassThrough{},
true),
p_c_grid(p_c_grid_),
c_element_op(c_element_op_),
p_ds(p_ds_),
StrideDs(stride_ds_)
{
}
CDataType* p_c_grid;
CElementwiseOperation c_element_op;
const ::std::array<const void*, NumDTensor> p_ds;
::std::array<index_t, NumDTensor> StrideDs;
};
using ReduceAdd = ck::reduce::Add;
using OutElementwiseOperation = CElementwiseOperation;
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
[](auto i) {
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same<CLayout, DLayout>::value)
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
else
return Number<1>{};
},
Number<NumDTensor>{});
using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
ReduceDataType, // InDataType
DsDataType, // DsDatatype
GemmAccDataType, // AccDataType
CDataType, // OutDataType
3, // Rank
1, // NumReduceDim
ReduceAdd,
PassThrough,
OutElementwiseOperation,
256, // BlockSize_
CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_
1, // KThreadSliceSize_
0, // InSrcVectorDim_
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_
CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
decltype(DsVectorLengthSequence)>;
struct Invoker : public BaseInvoker
{
float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
static constexpr index_t NumInDim = 3;
static constexpr index_t NumOutDim = 2;
::std::array<index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
::std::array<index_t, NumOutDim> out_lengths = {arg.M, arg.N};
::std::array<index_t, NumInDim> in_strides;
::std::array<index_t, NumOutDim> out_strides;
if constexpr(is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
in_strides = {arg.M * arg.N, arg.N, 1};
out_strides = {arg.N, 1};
}
else
{
in_strides = {arg.M * arg.N, 1, arg.M};
out_strides = {1, arg.M};
}
::std::array<int, 1> reduce_dims{0};
::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
::std::array<::std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
static_for<0, NumDTensor, 1>{}([&](auto i) {
DsLengths[i] = out_lengths;
using DLayout = ::std::__remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
DsStrides[i] = {arg.StrideDs[i], 1};
}
else
{
DsStrides[i] = {1, arg.StrideDs[i]};
}
});
auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
in_strides,
DsLengths,
DsStrides,
out_lengths,
out_strides,
reduce_dims,
arg.p_workspace_,
arg.p_ds,
arg.p_c_grid,
PassThrough{},
OutElementwiseOperation{});
auto invoker_ptr = reduce.MakeInvokerPointer();
float ave_time = 0;
if(reduce.IsSupportedArgument(argument_ptr.get()))
{
ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
}
else
{
throw ::std::runtime_error(
"The runtime parameters are not supported by the device instance.");
}
return ave_time;
}
float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
{
auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
// workspace required when doing two-kernel reduce or Ds present
const bool need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) &&
is_same<CDataType, ReduceDataType>::value);
if(need_workspace)
{
if(arg.p_workspace_ == nullptr)
{
throw ::std::runtime_error("using reduce, but empty workspace!");
}
arg.p_e_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
}
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
const auto kernel =
::ck::kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
ave_time = launch_and_time_kernel(
stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
}
else
{
const auto kernel =
::ck::kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
ave_time = launch_and_time_kernel(
stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg);
}
if(need_workspace)
{
ave_time += RunReduce(arg_, stream_config);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_wmma_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return GridwiseGemm::CalculateGridSize(M, N, KBatch);
}
static constexpr index_t GetBlockSize() { return BlockSize; }
static size_t GetSharedMemoryNumberOfByte()
{
return GridwiseGemm::GetSharedMemoryNumberOfByte();
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const ::std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const ::std::array<index_t, NumDTensor> stride_ds,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_c,
M,
N,
K,
StrideA,
StrideB,
stride_ds,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
::std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return ::std::make_unique<Invoker>(Invoker{});
}
// Polymorphic interfaces
::std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
::std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
::std::array<index_t, NumDTensor> DsStrides,
index_t StrideC,
index_t KSplit,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return ::std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
DsStrides,
StrideC,
KSplit,
a_element_op,
b_element_op,
c_element_op);
}
::std::string GetTypeString() const override
{
auto str = ::std::stringstream();
auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) {
switch(s)
{
case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave");
case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave");
}
return ::std::string("?");
};
auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) {
switch(v)
{
case BlockGemmPipelineVersion::v1: return ::std::string("v1");
case BlockGemmPipelineVersion::v2: return ::std::string("v2");
case BlockGemmPipelineVersion::v3: return ::std::string("v3");
case BlockGemmPipelineVersion::v4: return ::std::string("v4");
case BlockGemmPipelineVersion::v5: return ::std::string("v5");
}
return ::std::string("v?");
};
// clang-format off
str << "DeviceGemmWmmaUniversalReduce"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< ::std::string(ALayout::name)[0]
<< ::std::string(BLayout::name)[0]
<< ::std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WmmaTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WmmaRepeat: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString(BlkGemmPipeSched) << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString(BlkGemmPipelineVer) << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
// Need workspace if using split-K or have D tensors
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same<CDataType, ReduceDataType>::value))
{
return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
}
return 0;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -3,6 +3,11 @@
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <iostream>
#include <ostream>
#endif
#include "ck/utility/env.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
@@ -1049,6 +1054,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}

View File

@@ -70,9 +70,9 @@ struct WmmaTraitsBase<gfx12_t, ADType, BDType, CDType>
static constexpr index_t kRepeat = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABK0PerLane = 2;
static constexpr index_t kABK0PerLane = 1;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABK1PerLane = 4;
static constexpr index_t kABK1PerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 16;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,6 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -20,6 +21,7 @@ namespace instance {
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
@@ -326,7 +328,54 @@ void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadd
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
#ifdef CK_USE_WMMA
#if defined(CK_ENABLE_FP16)
void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_ENABLE_BF16)
void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
template <typename ADataType,
@@ -373,6 +422,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
@@ -395,6 +445,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
op_ptrs);
#endif
}
}
#endif
@@ -406,6 +462,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
@@ -420,6 +477,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
#endif
}
}
#endif
@@ -430,6 +493,7 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
#ifdef CK_USE_XDL
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
@@ -444,6 +508,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
#endif
}
}
#endif

View File

@@ -1,6 +1,7 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_UNIVERSAL_REDUCE_INSTANCES)
# XDL instances
list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
@@ -30,4 +31,11 @@ list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
)
# WMMA instances
list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES
device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
)
add_instance_library(device_gemm_universal_reduce_instance ${GEMM_UNIVERSAL_REDUCE_INSTANCES})

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec,
typename DsLayout = ck::Tuple<>,
typename DsDataType = ck::Tuple<>>
using device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = bhalf_t;
using Row = tensor_layout::gemm::RowMajor;
using PassThrough = element_wise::PassThrough;
void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
if(ck::is_gfx12_supported())
{
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmDefault,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmKPadding,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmMNPadding,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances<GemmMNKPadding,
DsLayout,
DsDataType>{});
}
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec,
typename DsLayout = ck::Tuple<>,
typename DsDataType = ck::Tuple<>>
using device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 4, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using BF16 = bhalf_t;
using Row = tensor_layout::gemm::RowMajor;
using PassThrough = element_wise::PassThrough;
void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
if(ck::is_gfx12_supported())
{
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmDefault,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmKPadding,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmMNPadding,
DsLayout,
DsDataType>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances<GemmMNKPadding,
DsLayout,
DsDataType>{});
}
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec,
typename DsLayout = ck::Tuple<>,
typename DsDataType = ck::Tuple<>>
using device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | |
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>,
DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = half_t;
using Row = tensor_layout::gemm::RowMajor;
using PassThrough = element_wise::PassThrough;
using Add = element_wise::Add;
using DsLayout_F16 = ck::Tuple<>;
using DsDataType_F16 = ck::Tuple<>;
void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout_F16,
Row,
F16,
F16,
DsDataType_F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
if(ck::is_gfx12_supported())
{
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmDefault,
DsLayout_F16,
DsDataType_F16>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmKPadding,
DsLayout_F16,
DsDataType_F16>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmMNPadding>{});
add_device_operation_instances(
instances,
device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
}
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -10,6 +10,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp"
@@ -86,10 +87,21 @@ bool profile_gemm_universal_reduce_impl(int do_verification,
switch(init_method)
{
case 0: break;
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

View File

@@ -68,7 +68,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_splitk.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp)
list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp)
@@ -90,6 +89,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
@@ -185,7 +185,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
@@ -221,6 +220,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)

View File

@@ -248,6 +248,7 @@ add_subdirectory(gemm_universal)
add_subdirectory(gemm_b_scale)
add_subdirectory(gemm_universal_streamk)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_universal_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)

View File

@@ -18,7 +18,7 @@ function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX)
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
endfunction()
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
create_tile_add_rmsnorm2d_rdquant_fwd("fp16")
create_tile_add_rmsnorm2d_rdquant_fwd("bf16")
else()

View File

@@ -1,4 +1,3 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp)
endif()

View File

@@ -27,21 +27,41 @@ class TestCkTileBatchedGemm : public ::testing::Test
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
template <typename ALayout, typename BLayout, typename CLayout>
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename GemmWarpConfig, typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false;
@@ -255,9 +275,13 @@ class TestCkTileBatchedGemm : public ::testing::Test
BatchStrideB,
BatchStrideC,
BatchCount};
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
ck_tile::stream_config{nullptr, false});
#if CK_TILE_USE_WMMA
invoke_batched_gemm<GemmWarpConfig_Wmma, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, false});
#else
invoke_batched_gemm<GemmWarpConfig_Mfma, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, false});
#endif
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC

View File

@@ -1,5 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_ck_tile_batched_transpose test_batched_transpose.cpp)
set_property(TARGET test_ck_tile_batched_transpose PROPERTY CXX_STANDARD 20)
else()

View File

@@ -1,4 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility)

View File

@@ -1,4 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx95")

View File

@@ -1,4 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)

View File

@@ -106,7 +106,7 @@ class TestCkTileElementwise : public ::testing::Test
ck_tile::index_t grid_size =
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
dim3 block(TestElementWiseShape::kBlockSize, 1, 1);
dim3 block = dim3(ew_kernel.BlockSize());
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log

View File

@@ -401,7 +401,7 @@ TEST_P(PagedKV, Test)
0, // scale_s
0, // logits_soft_cap
is_v_rowmajor, // is_v_rowmajor
def_lse, // lse
false, // lse
page_block_size, // page_block_size
false, // use_cache_batch_idx
"n", // bias_str

View File

@@ -12,16 +12,16 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
-enable-noalias-to-md-conversion=0
)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
@@ -30,37 +30,47 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
# On Radeon devices, build the WMMA version instead
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker)
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target ")
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx11|gfx12")
# On Radeon devices, build the WMMA version instead
add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline")
endif()

View File

@@ -13,6 +13,28 @@
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
struct GemmConfig_Mfma : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmConfig_Wmma : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename GemmConfig,
typename ADataType,
typename BDataType,
@@ -38,17 +60,17 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -130,7 +152,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
}
}
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
bool run_gemm_test_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
@@ -142,12 +167,12 @@ bool run_gemm_test_prec_type(std::string a_layout,
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Col{}, Row{});
}
else
@@ -160,22 +185,22 @@ bool run_gemm_test_prec_type(std::string a_layout,
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Col{}, Row{});
}
else
@@ -185,7 +210,7 @@ bool run_gemm_test_prec_type(std::string a_layout,
}
}
template <typename APrecType, typename BPrecType, typename CPrecType>
template <typename GemmConfig, typename APrecType, typename BPrecType, typename CPrecType>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -195,7 +220,8 @@ bool run_gemm_test(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
return run_gemm_test_prec_type<APrecType, BPrecType, CPrecType>(a_layout, b_layout, arg_parser);
return run_gemm_test_prec_type<GemmConfig, APrecType, BPrecType, CPrecType>(
a_layout, b_layout, arg_parser);
}
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
@@ -255,8 +281,15 @@ int run_gemm_combinations()
// Call the function with the current configuration
try
{
is_success = run_gemm_test<APrecType, BPrecType, CPrecType>(ARG_COUNT, argv) &&
#if CK_TILE_USE_WMMA
is_success = run_gemm_test<GemmConfig_Wmma, APrecType, BPrecType, CPrecType>(
ARG_COUNT, argv) &&
is_success;
#else
is_success = run_gemm_test<GemmConfig_Mfma, APrecType, BPrecType, CPrecType>(
ARG_COUNT, argv) &&
is_success;
#endif
}
catch(const ArgumentsNotSupportedException& e)
{

View File

@@ -220,6 +220,27 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

View File

@@ -325,6 +325,13 @@ int run_gemm_combinations()
// Call the function with the current configuration
try
{
#if CK_TILE_USE_WMMA
is_success = run_gemm_test<GemmConfigComputeV3_WMMA<CPrecType>,
APrecType,
BPrecType,
CPrecType>(ARG_COUNT, argv) &&
is_success;
#else
is_success = run_gemm_test<GemmConfigComputeV3<CPrecType>,
APrecType,
BPrecType,
@@ -335,6 +342,7 @@ int run_gemm_combinations()
BPrecType,
CPrecType>(ARG_COUNT, argv) &&
is_success;
#endif
}
catch(const ArgumentsNotSupportedException& e)
{

View File

@@ -1,10 +1,9 @@
# Currently ck_tile is only built on gfx9
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp)
add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp)
target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -86,7 +86,28 @@ class TestCkTileGemmMultiD : public ::testing::Test
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
template <typename ADataType,
struct GemmWarpConfig_Mfma
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
struct GemmWarpConfig_Wmma
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename GemmWarpConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -99,17 +120,17 @@ class TestCkTileGemmMultiD : public ::testing::Test
void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false;
@@ -359,8 +380,9 @@ class TestCkTileGemmMultiD : public ::testing::Test
StrideB,
stridesDs,
StrideE});
invoke_gemm_multi_d<ADataType,
#if CK_TILE_USE_WMMA
invoke_gemm_multi_d<GemmWarpConfig_Wmma,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -370,6 +392,19 @@ class TestCkTileGemmMultiD : public ::testing::Test
DsLayout,
ELayout,
CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});
#else
invoke_gemm_multi_d<GemmWarpConfig_Mfma,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDElementWiseFn>(args, ck_tile::stream_config{nullptr, false});
#endif
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE

View File

@@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
-enable-noalias-to-md-conversion=0
)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
add_gtest_executable(test_ck_tile_gemm_pipeline_wp test_gemm_pipeline_wp.cpp)
target_compile_options(test_ck_tile_gemm_pipeline_wp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -31,8 +31,10 @@ using F8Types = std::tuple<Row, Col, Row, F8, F8, F32, F16, Default, WeightPresh
using KernelTypesWeightPreshuffle = ::testing::Types<
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>,
F8Types
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
, F8Types
#endif
>;
// clang-format on

View File

@@ -63,6 +63,23 @@ struct config
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
};
template <typename Datatype>
struct config_wmma
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(Datatype);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test
{
@@ -79,13 +96,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
using GemmConfig = config<ADataType>;
static constexpr bool Persistent =
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
// TODO: expose tile size through test t-param ?
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
@@ -253,6 +269,48 @@ class TestCkTileGemmPipeline : public ::testing::Test
k_batches_ = {1};
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <bool PadM = true, bool PadN = true, bool PadK = true, bool Preshuffle = false>
void Run(const int M,
const int N,
@@ -263,11 +321,17 @@ class TestCkTileGemmPipeline : public ::testing::Test
{
for(auto kb : k_batches_)
{
RunSingle<PadM, PadN, PadK, Preshuffle>(M, N, K, StrideA, StrideB, StrideC, kb);
#if CK_TILE_USE_WMMA
RunSingle<config_wmma<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#else
RunSingle<config<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#endif
}
}
template <bool PadM, bool PadN, bool PadK, bool Preshuffle>
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
void RunSingle(const int M,
const int N,
const int K,
@@ -327,16 +391,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<BDataType> t_view({N / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
K / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(b_k_n.begin(), b_k_n.end(), t_view.begin());
ck_tile::HostTensor<BDataType> b_shuffle_host =
ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
@@ -354,7 +409,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
stride_B,
stride_C};
invoke_gemm<PadM, PadN, PadK, Preshuffle>(args, ck_tile::stream_config{nullptr, false});
invoke_gemm<GemmConfig, PadM, PadN, PadK, Preshuffle>(
args, ck_tile::stream_config{nullptr, false});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -1,4 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
endif()

View File

@@ -31,7 +31,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
using PersistentType = std::tuple_element_t<7, Tuple>;
static constexpr bool Persistent = PersistentType::value;
struct GroupedGemKernelParam
struct GroupedGemKernelParam_Mfma
{
static const bool kPadM = false;
static const bool kPadN = false;
@@ -51,13 +51,24 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 16;
};
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
{
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 64;
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
@@ -200,7 +211,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
@@ -460,15 +471,27 @@ class TestCkTileGroupedGemm : public ::testing::Test
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
#if CK_TILE_USE_WMMA
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
#else
invoke_grouped_gemm_persistent<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
#endif
}
else
{
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
#if CK_TILE_USE_WMMA
invoke_grouped_gemm<GroupedGemKernelParam_Wmma, ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
#else
invoke_grouped_gemm<GroupedGemKernelParam_Mfma, ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
#endif
}
// Copy results back to host for validation

View File

@@ -1,4 +1,3 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
endif()

View File

@@ -14,7 +14,7 @@ function(create_tile_layernorm2d_fwd SUFFIX)
target_compile_options(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
endfunction()
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")

View File

@@ -1,5 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})

View File

@@ -1,5 +1,5 @@
# Currently ck_tile is only built on gfx90a, gfx942 and gfx950
if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a")
# Currently ck_tile is only built on gfx90a, gfx942, gfx950, gfx11 and gfx12
if(GPU_TARGETS MATCHES "gfx942|gfx950|gfx90a|gfx11|gfx12")
function(add_moe_sorting_test EXECUTABLE USE_2D_BUF)
add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp)

View File

@@ -1,5 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
function(add_permute_test TARGET_NAME MAIN_SRC)
add_gtest_executable(${TARGET_NAME} ${MAIN_SRC})

View File

@@ -17,9 +17,11 @@
#include <utility>
#include <vector>
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
#endif
namespace detail {
template <int bytes>
@@ -193,6 +195,7 @@ class TestCkTilePermute : public ::testing::Test
return permute<DataType>(a, stream_config);
};
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((perm == std::string("0,1,4,2,5,3,6") || perm == std::string("0,1,2,4,5,3,6") ||
@@ -278,6 +281,7 @@ class TestCkTilePermute : public ::testing::Test
}
}
else
#endif
#endif
{
run_permute();

View File

@@ -1,4 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_reduce2d PRIVATE utility)

View File

@@ -59,7 +59,7 @@ class TestCkTileReduce : public ::testing::Test
using Kernel = ck_tile::Reduce<Problem>;
// Launch configuration
constexpr ck_tile::index_t kBlockSize = 256;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize =

View File

@@ -14,7 +14,7 @@ function(create_tile_rmsnorm2d_fwd SUFFIX)
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
endfunction()
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")

View File

@@ -1,5 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
function (add_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")

View File

@@ -10,8 +10,7 @@ function(add_tile_topk_softmax_test SUFFIX)
target_compile_options(${TEST_NAME} PRIVATE ${TEST_TOPK_SOFTMAX_COMPILE_OPTIONS})
endfunction()
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_tile_topk_softmax_test(fp16)
add_tile_topk_softmax_test(bf16)
else()

View File

@@ -0,0 +1,14 @@
add_gtest_executable(test_gemm_universal_reduce_bf16_wmma test_gemm_universal_reduce_bf16_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_reduce_bf16_wmma PRIVATE utility device_gemm_universal_reduce_instance)
endif()
add_gtest_executable(test_gemm_universal_reduce_fp16_wmma test_gemm_universal_reduce_fp16_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_reduce_fp16_wmma PRIVATE utility device_gemm_universal_reduce_instance)
endif()
add_gtest_executable(test_gemm_universal_reduce_bf16A_i8_wmma test_gemm_universal_reduce_bf16A_i8_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_universal_reduce_bf16A_i8_wmma PRIVATE utility device_gemm_universal_reduce_instance)
endif()

View File

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "gtest/gtest.h"
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
TEST(GemmUniversalReduce, BF16A_I8)
{
using Row = ck::tensor_layout::gemm::RowMajor;
int M = 512;
int N = 256;
int K = 128;
int KBatch = 1;
bool pass = true;
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::bhalf_t,
int8_t,
ck::Tuple<>,
float,
ck::bhalf_t,
Row,
Row,
ck::Tuple<>,
Row>(
true, 3, false, true, M, N, K, K, N, N, KBatch, 1, 10);
EXPECT_TRUE(pass);
}

View File

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "gtest/gtest.h"
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
TEST(GemmUniversalReduce, BF16)
{
using Row = ck::tensor_layout::gemm::RowMajor;
int M = 512;
int N = 256;
int K = 128;
int KBatch = 1;
bool pass = true;
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::bhalf_t,
ck::bhalf_t,
ck::Tuple<>,
float,
ck::bhalf_t,
Row,
Row,
ck::Tuple<>,
Row>(
true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10);
EXPECT_TRUE(pass);
}

View File

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "gtest/gtest.h"
#include "profiler/profile_gemm_universal_reduce_impl.hpp"
TEST(GemmUniversalReduce, FP16)
{
using Row = ck::tensor_layout::gemm::RowMajor;
int M = 512;
int N = 256;
int K = 128;
int KBatch = 1;
bool pass = true;
pass = pass && ck::profiler::profile_gemm_universal_reduce_impl<ck::half_t,
ck::half_t,
ck::Tuple<>,
float,
ck::half_t,
Row,
Row,
ck::Tuple<>,
Row>(
true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10);
EXPECT_TRUE(pass);
}