Wmma support for multiple ABD GEMM (#2803)

* multi_abd wmma support:

 - Add multiple A and B support to multiple D implementation (gridwise level)
 - Add multi_abd GEMM (device level)
 - Add instances (xdl parity)
 - Add tests (both xdl and wmma)
 - Add examples
 - Add ckProfiler support (both xdl and wmma)

* Fix bug in device print function

* Fix unused template parameter

* Fix batched gemm for multiABD gridwise implementation

* Fix gemm_universal_reduce with multiABDs gridwise implementation

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: 3d29bff2f0]
This commit is contained in:
Enrico Degregori
2025-09-23 03:49:06 +02:00
committed by GitHub
parent 7122e27fd8
commit 0b149c8695
38 changed files with 5343 additions and 312 deletions

View File

@@ -1,3 +1,7 @@
add_example_executable(example_gemm_multi_ABD_wmma_fp16 gemm_multi_ABD_wmma_fp16.cpp)
add_example_executable(example_gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_wmma_fastgelu_bf16_i8 gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp)

View File

@@ -0,0 +1,307 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Multiply;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
128,
64,
8,
8,
16,
16,
4,
2,
S<8, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, StrideB},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -0,0 +1,299 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Multiply;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
128,
64,
8,
8,
16,
16,
4,
2,
S<8, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, StrideB},
std::array<ck::index_t, NumDTensor>{},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -0,0 +1,362 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using ELayout = Row;
struct AddScale
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr void
operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
{
const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
auto r_v_t = ck::vector_type<ck::half_t, 4>{};
r_v_t.AsType<ck::half_t>()(I0) =
scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
r_v_t.AsType<ck::half_t>()(I1) =
scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
r_v_t.AsType<ck::half_t>()(I2) =
scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
r_v_t.AsType<ck::half_t>()(I3) =
scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
a = r_v_t.AsType<ck::half4_t>()[I0];
}
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
{
a = scale * (a0 + a1);
}
// this attribute controls the copy_function applying element_wise_op with
// pack4_data
constexpr const static bool is_pack4_invocable = true;
float scale = 1.0;
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
ck::Tuple<ALayout, ALayout>,
ck::Tuple<BLayout>,
ck::Tuple<DLayout>,
ELayout,
ck::Tuple<ADataType, ADataType>,
ck::Tuple<BDataType>,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
256,
128,
32,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
1,
1,
8,
0,
1,
1,
S<1, 64, 1, 4>,
S<8, 8, 8>>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
float alpha = 1.0f;
float beta = 1.0f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
a1_device_buf.ToDevice(a1_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, 2>{StrideA, StrideA},
std::array<ck::index_t, 1>{StrideB},
std::array<ck::index_t, 1>{StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<ADataType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
PassThrough,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -0,0 +1,296 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using D1Layout = D0Layout;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyAddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
128,
64,
8,
8,
16,
16,
4,
2,
S<8, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 1;
constexpr ck::index_t NumDTensor = 2;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB},
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -15,6 +15,151 @@
namespace ck {
namespace utility {
template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
struct RotatingMemWrapperMultiABD
{
static constexpr index_t NumAs = AsDataType::Size();
static constexpr index_t NumBs = BsDataType::Size();
static constexpr index_t NumDs = DsDataType::Size();
using AsGridPointer = decltype(Argument::p_as_grid);
using BsGridPointer = decltype(Argument::p_bs_grid);
using DsGridPointer = decltype(Argument::p_ds_grid);
RotatingMemWrapperMultiABD() = delete;
RotatingMemWrapperMultiABD(Argument& arg_,
std::size_t rotating_count_,
std::array<std::size_t, NumAs> size_as_,
std::array<std::size_t, NumBs> size_bs_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
size_as(size_as_),
size_bs(size_bs_),
size_ds(size_ds_)
{
p_as_grids.push_back(arg.p_as_grid);
p_bs_grids.push_back(arg.p_bs_grid);
p_ds_grids.push_back(arg.p_ds_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
AsGridPointer as_buffer;
static_for<0, NumAs, 1>{}([&](auto j) {
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
static_cast<const void*>(p_as_grids[0][j]),
size_as_[j],
hipMemcpyDeviceToDevice));
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
});
p_as_grids.push_back(as_buffer);
}
{
BsGridPointer bs_buffer;
static_for<0, NumBs, 1>{}([&](auto j) {
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
static_cast<const void*>(p_bs_grids[0][j]),
size_bs_[j],
hipMemcpyDeviceToDevice));
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
});
p_bs_grids.push_back(bs_buffer);
}
{
DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j],
hipMemcpyDeviceToDevice));
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
});
p_ds_grids.push_back(ds_buffer);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_as_grid = p_as_grids[idx];
arg.p_bs_grid = p_bs_grids[idx];
arg.p_ds_grid = p_ds_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapperMultiD: { size_a: {";
static_for<0, NumAs, 1>{}(
[&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
std::cout << "}, size_b: {";
static_for<0, NumBs, 1>{}(
[&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapperMultiABD()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_as_grid = p_as_grids[0];
arg.p_bs_grid = p_bs_grids[0];
arg.p_ds_grid = p_ds_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
static_for<0, NumAs, 1>{}([&](auto j) {
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
});
static_for<0, NumBs, 1>{}([&](auto j) {
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
});
static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
});
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::array<std::size_t, NumAs> size_as = {0};
std::array<std::size_t, NumBs> size_bs = {0};
std::array<std::size_t, NumDs> size_ds = {0};
std::vector<AsGridPointer> p_as_grids;
std::vector<BsGridPointer> p_bs_grids;
std::vector<DsGridPointer> p_ds_grids;
};
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
@@ -318,6 +463,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// total_time += cur_time;
// #endif
#if !defined(CK_USE_WMMA)
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
@@ -326,6 +472,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid));
}
#endif
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -55,6 +55,155 @@ struct DeviceGemmMultipleABD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// GEMM:
// input : A0[M, K], B0[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleABDSplitK : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
std::array<ck::index_t, NumATensor> StrideAs,
std::array<ck::index_t, NumBTensor> StrideBs,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleABDSplitK in contexts where DeviceGemmMultipleABD is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleABD and DeviceGemmMultipleABDSplitK
/// is that DeviceGemmMultipleABDSplitK::MakeArgumentPointer requires an additional parameter
/// KBatch which is explicitly passed as 1 by this wrapper.
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleABDSplitKWrapper : public DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
std::array<ck::index_t, NumATensor> StrideAs,
std::array<ck::index_t, NumBTensor> StrideBs,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_as,
p_bs,
p_ds,
p_e,
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
1, // KBatch
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -64,9 +64,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// shift A matrices pointer for splitk
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
using ADataType_ =
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
});
// shift B matrices pointer for splitk
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
using BDataType_ =
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
});
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
p_as_grid_shift,
p_bs_grid_shift,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
p_shared,
@@ -278,8 +296,8 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
@@ -346,15 +364,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BElementwiseOperation b_element_op_,
CElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
: GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
std::array<const void*, 1>{p_b_grid_},
std::array<const void*, 0>{}, // p_ds_grid_
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
std::array<index_t, 1>{StrideA_},
std::array<index_t, 1>{StrideB_},
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
k_batch_,
@@ -423,26 +441,33 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
// Packed sizes are 1 for all implemented data types but we include it anyway
// for future compatibility.
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
// Note: the grid descriptors and size_a / size_b do *not* take batching into
// account, so we have to manually multiply overall buffer sizes for rotating
// memory by batch.
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.Batch * size_a_buffer,
arg_.Batch * size_b_buffer);
std::array<std::size_t, 1> size_as_buffers;
size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
std::array<std::size_t, 1> size_bs_buffers;
size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
ck::utility::RotatingMemWrapperMultiABD<Argument,
Tuple<ADataType>,
Tuple<BDataType>,
Tuple<>>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
std::array<std::size_t, 0>{});
rotating_mem.Print();
auto run_flush_cache = [&]() {

View File

@@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// E{M,N} = CDE_op(A_op(As{M,K}...) * B_op(Bs{K,N}...), Ds{M,N}...)
/// Where As, Bs, Ds are input tensors and E is the output tensor. The A/B_op are
/// elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension into multiple working groups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam AsLayout A tensors data layouts.
/// @tparam BsLayout B tensors data layouts.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam AsDataType A tensors data types.
/// @tparam BsDataType B tensors data types.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultipleABD_Wmma_CShuffleV3
: public DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
// Note: Pass multiple layout but then using only the first one
// This is to replicate xdl functionality but it should be extended
using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
AsDataType,
BsDataType,
DsDataType,
EDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::array<const void*, GridwiseGemm::NumATensor> p_as,
std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_as,
p_bs,
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, GridwiseGemm::NumATensor> p_as,
std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
std::array<ck::index_t, GridwiseGemm::NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleABD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", ";
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
str << std::string(ALayout_::name)[0];
});
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
str << std::string(BLayout_::name)[0];
});
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
str << std::string(DLayout::name)[0];
});
str << std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -193,8 +193,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -244,8 +244,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType,
EDataType,
MPerBlock,
@@ -291,15 +291,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideE,
KBatch,
@@ -328,15 +328,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideE,
KBatch,

View File

@@ -182,8 +182,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
@@ -233,8 +233,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
Tuple<>,
CDataType,
MPerBlock,
@@ -283,15 +283,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch,
@@ -317,15 +317,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch,

View File

@@ -91,8 +91,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
@@ -144,8 +145,8 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
Tuple<>,
CDataType,
MPerBlock,
@@ -195,15 +196,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,
@@ -233,15 +234,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,

View File

@@ -23,8 +23,8 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
index_t MPerBlock,
@@ -88,15 +88,24 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
std::array<std::size_t, GridwiseGemm::NumATensor> size_as_buffers;
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
size_as_buffers[i] = a_grid_desc_ak0_m_ak1[i].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
});
std::array<std::size_t, GridwiseGemm::NumBTensor> size_bs_buffers;
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
size_bs_buffers[i] = b_grid_desc_bk0_n_bk1[i].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
});
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
@@ -108,12 +117,13 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
size_a_buffer,
size_b_buffer,
size_ds_buffers);
ck::utility::
RotatingMemWrapperMultiABD<Argument, AsDataType, BsDataType, DsDataType>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {

View File

@@ -98,8 +98,8 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BLayout,
Tuple<>,
CLayout,
ADataType,
BDataType,
Tuple<ADataType>,
Tuple<BDataType>,
GemmAccDataType,
ReduceDataType,
Tuple<>,
@@ -147,15 +147,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
Argument(std::array<const void*, 1> p_a_grid_,
std::array<const void*, 1> p_b_grid_,
const ::std::array<const void*, NumDTensor> p_ds_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, 1> StrideA_,
std::array<index_t, 1> StrideB_,
const ::std::array<index_t, NumDTensor> stride_ds_,
index_t StrideC_,
index_t KBatch_,
@@ -430,15 +430,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
stride_ds,
StrideC,
KBatch,
@@ -472,15 +472,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return ::std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
return ::std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
DsStrides,
StrideC,
KSplit,

View File

@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -39,8 +40,8 @@ namespace ck {
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AsDataType A tensors data types.
/// @tparam BsDataType B tensors data types.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
@@ -129,8 +130,8 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -181,8 +182,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -233,8 +234,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -305,8 +306,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::CalculateMPadded;
using Base::CalculateNBlock;
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeAsGridDescriptor_AK0_M_AK1;
using Base::MakeBsGridDescriptor_BK0_N_BK1;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
@@ -320,24 +321,30 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;
using typename Base::AsGridPointer;
using typename Base::BsGridPointer;
using typename Base::DsGridPointer;
using AsDataType_ = AsDataType;
using BsDataType_ = BsDataType;
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideAs{StrideAs_},
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
@@ -355,7 +362,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
<< "SAs: {";
static_for<0, NumATensor, 1>{}([&](auto i) {
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
});
std::cout << "}, " << "SBs: {";
static_for<0, NumBTensor, 1>{}([&](auto i) {
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
});
std::cout << "}, ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
@@ -373,8 +388,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
@@ -391,15 +406,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
@@ -407,9 +422,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
: Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
@@ -417,9 +432,27 @@ struct GridwiseGemm_wmma_cshuffle_v3
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
@@ -434,8 +467,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
return (Problem::KBatch > 1) && (!is_reduce);
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
AsGridPointer p_as_grid;
BsGridPointer p_bs_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
@@ -452,29 +485,39 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
// Note: in xdl implementation multiple AB supports one layout
// but multiple strides, so we create an array of offsets with
// the same values.
// It should be fixed later on. Once we will have a thread transfer
// more flexible.
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = k_id * karg.KRead / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = k_id * k0_offset / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
@@ -497,8 +540,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t c_reduce_offset;
};
@@ -514,8 +557,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
@@ -524,10 +567,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -562,20 +605,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
@@ -595,10 +638,26 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i];
});
// shift B matrices pointer for splitk
BsGridPointer p_bs_grid_splitk;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,

View File

@@ -22,8 +22,9 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename BScaleType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -76,8 +77,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -123,15 +124,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
PermuteA,
PermuteB>
{
using BScaleType = ck::half_t;
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -202,8 +201,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::CalculateMPadded;
using Base::CalculateNBlock;
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeAsGridDescriptor_AK0_M_AK1;
using Base::MakeBsGridDescriptor_BK0_N_BK1;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
@@ -217,7 +216,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;
using typename Base::AsGridPointer;
using typename Base::BsGridPointer;
using typename Base::DsGridPointer;
struct Problem
@@ -225,8 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
@@ -234,8 +237,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideAs{StrideAs_},
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleB{StrideScaleB_},
@@ -254,7 +257,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
<< "SAs: {";
static_for<0, NumATensor, 1>{}([&](auto i) {
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
});
std::cout << "}, " << "SBs: {";
static_for<0, NumBTensor, 1>{}([&](auto i) {
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
});
std::cout << "}, ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
@@ -273,8 +284,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleB;
@@ -292,15 +303,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
@@ -310,9 +321,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
: Problem{M_,
N_,
K_,
StrideAs_,
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_b_scale_grid{p_b_scale_grid_},
@@ -321,6 +340,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -338,8 +373,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
return (Problem::KBatch > 1) && (!is_reduce);
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
AsGridPointer p_as_grid;
BsGridPointer p_bs_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
@@ -355,29 +390,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
// Note: in xdl implementation multiple AB supports one layout
// but multiple strides, so we create an array of offsets with
// the same values.
// It should be fixed later on. Once we will have a thread transfer
// more flexible.
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = k_id * karg.KRead / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = k_id * k0_offset / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
@@ -410,8 +455,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t c_reduce_offset;
};
@@ -423,7 +468,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK, typename BScaleType>
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
@@ -488,8 +533,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const BScaleType* p_b_scale_grid,
@@ -499,10 +544,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -542,20 +587,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
@@ -575,10 +620,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i];
});
// shift B matrices pointer for splitk
BsGridPointer p_bs_grid_splitk;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -61,8 +62,8 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -119,6 +120,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
using LDSTypeA =
typename std::conditional<(NumATensor > 1),
ComputeTypeA,
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
using LDSTypeB =
typename std::conditional<(NumBTensor > 1),
ComputeTypeB,
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
@@ -136,14 +149,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t>)
return 2;
else
return 1;
@@ -230,6 +243,31 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
static constexpr auto MakeAsGridPointer()
{
return generate_tuple(
[&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
return static_cast<const ADataType_*>(nullptr);
},
Number<NumATensor>{});
}
static constexpr auto MakeBsGridPointer()
{
return generate_tuple(
[&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
return static_cast<const BDataType_*>(nullptr);
},
Number<NumBTensor>{});
}
using AsGridPointer = decltype(MakeAsGridPointer());
using BsGridPointer = decltype(MakeBsGridPointer());
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
@@ -314,6 +352,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
__host__ __device__ static auto
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
const index_t MPad,
const index_t K,
const index_t KPad,
const std::array<index_t, NumATensor>& StrideAs,
const index_t AK0)
{
return generate_tuple(
[&](auto i) {
return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
},
Number<NumATensor>{});
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
@@ -330,7 +383,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
static_assert(!(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t> &&
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
@@ -424,6 +477,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
const index_t KPad,
const index_t N,
const index_t NPad,
const std::array<index_t, NumBTensor>& StrideBs,
const index_t BK0)
{
return generate_tuple(
[&](auto i) {
return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
},
Number<NumBTensor>{});
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
{
@@ -557,7 +625,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -604,20 +672,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto KThreadRead = 64 / MPerWmma;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128)
? 1
: ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
: ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0
? M0
: 128 / (AK1Number * MPerWmma * sizeof(ADataType)));
: 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
@@ -694,7 +762,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -738,20 +806,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto KThreadRead = 64 / NPerWmma;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(BDataType));
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128)
constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128)
? 1
: ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0
: ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0
? N0
: 128 / (BK1Number * NPerWmma * sizeof(BDataType)));
: 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
@@ -836,8 +904,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
ComputeTypeB,
AccDataType,
@@ -1120,11 +1188,24 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
}
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
if constexpr(numElements > 1)
{
return array;
}
else
{
return array[I0];
}
}
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1133,13 +1214,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -1152,10 +1233,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& num_k_block_per_scale,
BScaleStruct& b_scale_struct)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto as_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
},
Number<NumATensor>{});
const auto bs_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
},
Number<NumBTensor>{});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1183,66 +1274,144 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// workaround because v7r2 is not as general as v4r1
auto get_a_blockwise_transfer = [&]() {
if constexpr(NumATensor > 1)
{
const auto idx_as_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<LDSTypeA>,
AGridDesc_AK0_M_K1,
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
decltype(as_grid_desc_ak0_m_ak1[I0]),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1[I0],
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_transfer();
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// workaround because v7r2 is not as general as v4r1
auto get_b_blockwise_transfer = [&]() {
if constexpr(NumBTensor > 1)
{
const auto idx_bs_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<LDSTypeB>,
BGridDesc_BK0_N_K1,
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
decltype(bs_grid_desc_bk0_n_bk1[I0]),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1[I0],
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto b_blockwise_copy = get_b_blockwise_transfer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1250,12 +1419,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
reinterpret_cast<LDSTypeB*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(LDSTypeA) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1267,25 +1436,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
get_first_element_workaround<NumATensor>(as_grid_buf),
a_block_buf,
a_block_slice_copy_step,
get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
get_first_element_workaround<NumBTensor>(bs_grid_buf),
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
{

View File

@@ -1,16 +1,26 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
)
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
)
add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})

View File

@@ -0,0 +1,109 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using D0Layout = Row;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename BsLayout,
typename DsLayout,
typename BsDataType,
typename DsDataType,
typename BElementOp,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer|
//###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | |
//###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | |
//###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename BsLayout,
typename DsLayout,
typename BsDataType,
typename DsDataType,
typename BElementOp,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer|
//###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | |
//###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | |
//###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 32, 16, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 16, 32, 256, 8, 8, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Multiply;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer|
//###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | |
//###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | |
//###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
Multiply,
AddFastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAddFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAddFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABDSplitK<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,424 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
template <typename... X, typename... Y>
auto concat_tuple_of_refs(ck::Tuple<X&...>& tx, ck::Tuple<Y&...>& ty)
{
return ck::unpack2(
[&](auto&&... zs) { return ck::Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
template <typename AsDataType,
typename BsDataType,
typename AccDataType,
typename DsDataType,
typename EDataType,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
bool profile_gemm_multi_abd_impl(int do_verification,
int init_method,
bool /*do_log*/,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD,
int StrideE)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
static constexpr index_t NumATensor = AsDataType::Size();
auto as_m_k = generate_tuple(
[&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return Tensor<ADataType>(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
},
Number<NumATensor>{});
static constexpr index_t NumBTensor = BsDataType::Size();
auto bs_k_n = generate_tuple(
[&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return Tensor<BDataType>(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
},
Number<NumBTensor>{});
static constexpr index_t NumDTensor = DsDataType::Size();
auto ds_m_n = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return Tensor<DDataType>(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
},
Number<NumDTensor>{});
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
static_for<0, NumATensor, 1>{}(
[&](auto i) { std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; });
static_for<0, NumBTensor, 1>{}(
[&](auto i) { std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; });
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; });
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
as_m_k(i).GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
bs_k_n(i).GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
ds_m_n(i).GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
});
break;
default:
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
as_m_k(i).GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
bs_k_n(i).GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
ds_m_n(i).GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
});
}
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// run reference
if(do_verification)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<AccDataType> c_m_n({M, N});
using AComputeType =
typename std::conditional<(NumATensor > 1),
EDataType,
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
auto get_a_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<AElementOp, PassThrough>)
{
return as_m_k(Number<0>{});
}
else
{
Tensor<AComputeType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
// result
auto data_refs1 = ck::tie(a_m_k(m, k));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
Number<NumATensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(a_element_op, data_refs);
}
}
return a_m_k;
}
};
using BComputeType =
typename std::conditional<(NumBTensor > 1),
EDataType,
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
auto get_b_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<AElementOp, PassThrough>)
{
return bs_k_n(Number<0>{});
}
else
{
Tensor<BComputeType> b_k_n({K, N});
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N; ++n)
{
// result
auto data_refs1 = ck::tie(b_k_n(k, n));
// inputs
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
Number<NumBTensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(b_element_op, data_refs);
}
}
return b_k_n;
}
};
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
BComputeType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
get_a_matrix(), get_b_matrix(), c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
// compulsory
auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n));
// optional (if multiple Ds)
auto data_refs2 =
generate_tie([&](auto i) -> auto& { return ds_m_n(Number<i>{})(m, n); },
Number<NumDTensor>{});
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
unpack(cde_element_op, data_refs);
}
}
}
std::array<DeviceMem*, NumATensor> as_device_buf;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
as_device_buf[i] = new DeviceMem(sizeof(ADataType) * as_m_k(i).mDesc.GetElementSpaceSize());
});
std::array<DeviceMem*, NumBTensor> bs_device_buf;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
bs_device_buf[i] = new DeviceMem(sizeof(BDataType) * bs_k_n(i).mDesc.GetElementSpaceSize());
});
std::array<DeviceMem*, NumDTensor> ds_device_buf;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
ds_device_buf[i] = new DeviceMem(sizeof(DDataType) * ds_m_n(i).mDesc.GetElementSpaceSize());
});
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
static_for<0, NumATensor, 1>{}(
[&](auto i) { as_device_buf[i]->ToDevice(as_m_k(i).mData.data()); });
static_for<0, NumBTensor, 1>{}(
[&](auto i) { bs_device_buf[i]->ToDevice(bs_k_n(i).mData.data()); });
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_device_buf[i]->ToDevice(ds_m_n(i).mData.data()); });
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
{
std::array<const void*, NumATensor> as_pointer;
std::array<ck::index_t, NumATensor> as_stride;
static_for<0, NumATensor, 1>{}([&](auto i) {
as_pointer[i] = as_device_buf[i]->GetDeviceBuffer();
as_stride[i] = StrideA;
});
std::array<const void*, NumBTensor> bs_pointer;
std::array<ck::index_t, NumBTensor> bs_stride;
static_for<0, NumBTensor, 1>{}([&](auto i) {
bs_pointer[i] = bs_device_buf[i]->GetDeviceBuffer();
bs_stride[i] = StrideB;
});
std::array<const void*, NumDTensor> ds_pointer;
std::array<ck::index_t, NumDTensor> ds_stride;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_pointer[i] = ds_device_buf[i]->GetDeviceBuffer();
ds_stride[i] = StrideD;
});
auto argument_ptr = op_ptr->MakeArgumentPointer(as_pointer,
bs_pointer,
ds_pointer,
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
as_stride,
bs_stride,
ds_stride,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init E to zero before profiling a kernel
e_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t sizeADataType = 0;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
sizeADataType = std::max(sizeADataType, sizeof(ADataType));
});
std::size_t sizeBDataType = 0;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
sizeBDataType = std::max(sizeBDataType, sizeof(BDataType));
});
std::size_t num_btype =
sizeADataType * M * K + sizeBDataType * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
static_for<0, NumATensor, 1>{}([&](auto i) { delete as_device_buf[i]; });
static_for<0, NumBTensor, 1>{}([&](auto i) { delete bs_device_buf[i]; });
static_for<0, NumDTensor, 1>{}([&](auto i) { delete ds_device_buf[i]; });
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -96,6 +96,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
@@ -234,6 +235,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_multi_abd_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
BF16_I8_BF16_BF16, // 0
};
enum struct GemmElementOp
{
PASS_THROUGH, // 0
MULTIPLY, // 1
ADD, // 2
FASTGELU, // 3
ADD_FASTGELU, // 4
MULTIPLY_ADD, // 5
MULTIPLY_FASTGELU, // 6
MULTIPLY_ADD_FASTGELU, // 7
};
#define OP_NAME "gemm_multi_abd"
#define OP_DESC "GEMM_Multiple_ABD"
int profile_gemm_multi_abd(int argc, char* argv[])
{
if(argc != 18)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: bf16@int8/bf16->bf16;)\n");
printf("arg3: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n");
printf(" 1: E[m, n] = A[m, k] * B[n, k];\n");
printf(" 2: E[m, n] = A[k, m] * B[k, n];\n");
printf(" 3: E[m, n] = A[k, m] * B[n, k])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8: number of As (1)\n");
printf("arg9: number of Bs (1/2)\n");
printf("arg10: number of Ds (0/1/2)\n");
printf("arg11 to 17: M, N, K, StrideA, StrideB, StrideE, StrideD\n");
// clang-format on
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int num_as = std::stoi(argv[8]);
const int num_bs = std::stoi(argv[9]);
const int num_ds = std::stoi(argv[10]);
const int M = std::stoi(argv[11]);
const int N = std::stoi(argv[12]);
const int K = std::stoi(argv[13]);
const int StrideA = std::stoi(argv[14]);
const int StrideB = std::stoi(argv[15]);
const int StrideE = std::stoi(argv[16]);
const int StrideD = std::stoi(argv[17]);
using F32 = float;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
auto profile = [&](auto b_layout, auto b_element_op, auto cde_element_op, auto num_d_tensor) {
using ADataType = BF16;
using B0DataType = I8;
using B1DataType = BF16;
using DDataType = BF16;
using EDataType = BF16;
using ALayout = Row;
using BLayout = decltype(b_layout);
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = decltype(b_element_op);
using CDEElementOp = decltype(cde_element_op);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD = ck::is_same_v<DLayout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
constexpr auto NumberDTensor = decltype(num_d_tensor){};
// Only num_d_tensor == 0 and 1 are supported
using DsDataType = typename std::
conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple<DDataType>>::type;
using DsLayout =
typename std::conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple<DLayout>>::type;
bool pass = ck::profiler::profile_gemm_multi_abd_impl<ck::Tuple<ADataType>,
ck::Tuple<B0DataType, B1DataType>,
F32,
DsDataType,
EDataType,
ck::Tuple<ALayout>,
ck::Tuple<BLayout, BLayout>,
DsLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideD < 0) ? DefaultStrideD : StrideD,
(StrideE < 0) ? DefaultStrideE : StrideE);
return pass ? 0 : 1;
};
// num_as == 1 is only supported
if(data_type != GemmDataType::BF16_I8_BF16_BF16 || num_as != 1)
{
std::cout << "The provided input parameters are not supported" << std::endl;
return 1;
}
// Supported configurations
if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 1)
{
return profile(Row{}, Multiply{}, AddFastGelu{}, ck::Number<1>{});
}
else if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 0)
{
return profile(Row{}, Multiply{}, FastGelu{}, ck::Number<0>{});
}
else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 1)
{
return profile(Col{}, Multiply{}, AddFastGelu{}, ck::Number<1>{});
}
else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 0)
{
return profile(Col{}, Multiply{}, FastGelu{}, ck::Number<0>{});
}
std::cout << "The provided input parameters are not supported" << std::endl;
return 1;
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multi_abd);

View File

@@ -243,6 +243,7 @@ add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_multi_abd)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_universal)
add_subdirectory(gemm_b_scale)

View File

@@ -0,0 +1,9 @@
add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance)
endif()
add_gtest_executable(test_gemm_multi_abd_xdl test_gemm_multi_abd_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_multi_abd_xdl PRIVATE utility device_gemm_multi_abd_instance)
endif()

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
namespace ck {
namespace test {
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
template <typename Tuple>
class TestGemmCommon : public ::testing::Test
{
protected:
using AsLayout = std::tuple_element_t<0, Tuple>;
using BsLayout = std::tuple_element_t<1, Tuple>;
using DsLayout = std::tuple_element_t<2, Tuple>;
using ELayout = Row;
using AsDataType = std::tuple_element_t<3, Tuple>;
using BsDataType = std::tuple_element_t<4, Tuple>;
using DsDataType = std::tuple_element_t<5, Tuple>;
using EDataType = std::tuple_element_t<6, Tuple>;
using AElementOp = std::tuple_element_t<7, Tuple>;
using BElementOp = std::tuple_element_t<8, Tuple>;
using CDEElementOp = std::tuple_element_t<9, Tuple>;
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {
{16, 32, 64}, {512, 1024, 2048}, {1024, 512, 32}};
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
// Assuming same layout for all A matrices (same applies for Bs and Ds)
int StrideA = ck::is_same_v<remove_cvref_t<tuple_element_t<0, AsLayout>>, Row> ? K : M;
int StrideB = ck::is_same_v<remove_cvref_t<tuple_element_t<0, BsLayout>>, Row> ? N : K;
// In case no D matrices are provided, set stride to 0
int StrideD = 0;
if constexpr(DsDataType::Size() > 0)
{
StrideD = ck::is_same_v<remove_cvref_t<tuple_element_t<0, DsLayout>>, Row> ? N : M;
}
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
all_success =
all_success & ck::profiler::profile_gemm_multi_abd_impl<AsDataType,
BsDataType,
F32,
DsDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(
1, 2, false, false, M, N, K, StrideA, StrideB, StrideD, StrideE);
}
EXPECT_TRUE(all_success);
}
};
} // namespace test
} // namespace ck

View File

@@ -0,0 +1,154 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_multi_abd_impl.hpp"
#include "test_gemm_common.hpp"
namespace ck {
namespace test {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using Add = ck::tensor_operation::element_wise::Add;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
} // namespace test
} // namespace ck

View File

@@ -0,0 +1,154 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_multi_abd_impl.hpp"
#include "test_gemm_common.hpp"
namespace ck {
namespace test {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using Add = ck::tensor_operation::element_wise::Add;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
Add>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Multiply,
AddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
FastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Multiply,
PassThrough>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
} // namespace test
} // namespace ck