Implement batched gemm add relu gemm add for rdna4 (#3391)

* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implenentation

* wip: many fixes in implementation of batched gemm gemm multiple d

* wip: batched gemm gemm multiple d gridwise op compiling, not working yet

* fix: incorrect d0 grid indexing in batched gemm gemm multipled

* feat: add instances for batched gemm add relu gemm add

* chore: configure instance with low vector transfer size for odd sizes

* chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense

* fix: upate device_batched_gemm_gemm_wmma to work with new gridwise changes

* fix: disable odd size tests on XDL archs

* chore: removed temporary logging

* chore: update some references to C tensor to E tensor

* Tentative fix for example template params

* Tentative fix for non-multi-D batched gemm gemm device impl.

* Tentative fix for xdl example template params

* Tentative fix for profiler build on gfx90a

* chore: improve device batched gemm gemm multi D comment to include all ops and dimensions

* chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply

* fix: make the gemm1 data types match what happens in the device op

* feat: add d0s/d1s datatypes and layouts to the device op type string

* chore: change element-wise op so addition happens in fp32

* chore: add static asserts for gemm0/gemm1 calculated wave sizes

* chore: also updated other element-wise ops to use fp32 calculations

* chore: log number of supported instances

* chore: update instance comment

* chore: disable kernel timing in example by default

* fix: gemm1 wave size calculation

* fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions

* chore: remove increased tolerance in batched gemm gemm multiple d example

* chore: add comment explaining that verification fails for certain input values

* chore: clarify instance comment

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>

[ROCm/composable_kernel commit: d5ae81b292]
This commit is contained in:
Erwin Terpstra
2026-01-20 22:06:59 +01:00
committed by GitHub
parent b8595c5684
commit fa56471c91
22 changed files with 2956 additions and 499 deletions

View File

@@ -2,3 +2,4 @@
# SPDX-License-Identifier: MIT
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp)

View File

@@ -0,0 +1,135 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/*
Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o]
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "element_ops.h"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using A0DataType = F16;
using B0DataType = F16;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using D1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using E1DataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D00Layout = Row;
using D01Layout = Row;
using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3<
A0Layout,
B0Layout,
ck::Tuple<D00Layout, D01Layout>,
B1Layout,
ck::Tuple<D1Layout>,
E1Layout,
A0DataType,
B0DataType,
ck::Tuple<D00DataType, D01DataType>,
B1DataType,
ck::Tuple<D1DataType>,
E1DataType,
AccDataType,
CShuffleDataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp,
GemmSpec,
32, // BlockSize
16, // MPerBlock
64, // LPerBlock
64, // KPerBlock
64, // NPerBlock (Gemm1NPerBlock)
64, // LTilePerBlock (Gemm1KPerBlock)
8, // AK1
8, // BK1
8, // L1 (B1K1)
16, // MPerWmma
16, // LPerWmma
1, // MRepeat
4, // LRepeat (Gemm0NRepeat)
4, // NRepeat (Gemm1NRepeat)
S<2, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<2, 16, 1>, // B0BlockTransferThreadClusterLengths_K0_L_K1
S<1, 0, 2>, // B0BlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // B0BlockTransferSrcAccessOrder
2, // B0BlockTransferSrcVectorDim
8, // B0BlockTransferSrcScalarPerVector
8, // B0BlockTransferDstScalarPerVector_K1
false, // B0BlockLdsAddExtraL
4, // CDE0BlockTransferSrcScalarPerVector
S<2, 16, 1>, // B1BlockTransferThreadClusterLengths_L0_N_L1
S<0, 2, 1>, // B1BlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // B1BlockTransferSrcAccessOrder
1, // B1BlockTransferSrcVectorDim
4, // B1BlockTransferSrcScalarPerVector
2, // B1BlockTransferDstScalarPerVector_L1
true, // B1BlockLdsAddExtraN
1, // CShuffleMRepeatPerShuffle
2, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 2>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
int main(int argc, char* argv[]) { return run_example(argc, argv); }

View File

@@ -22,6 +22,8 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "element_ops.h"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
@@ -39,11 +41,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using A0DataType = F16;
using B0DataType = F16;
using Acc0DataType = F32;
using AccDataType = F32;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using Acc1DataType = F32;
using C1ShuffleDataType = F32;
using D1DataType = F16;
using E1DataType = F16;
@@ -56,58 +57,6 @@ using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;
// E = Relu(C + D0 + D1)
struct AddAddRelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
}
};
// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + (d0 + d1);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
}
};
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
@@ -131,10 +80,10 @@ using DeviceGemmInstance =
E1Layout,
A0DataType,
B0DataType,
Acc0DataType,
AccDataType,
ck::Tuple<D00DataType, D01DataType>,
B1DataType,
Acc1DataType,
AccDataType,
C1ShuffleDataType,
ck::Tuple<D1DataType>,
E1DataType,
@@ -191,337 +140,5 @@ using DeviceGemmInstance =
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideD00 = -1;
ck::index_t StrideD01 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideD1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideD00 = -1;
ck::index_t BatchStrideD01 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideD1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 23)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA0 = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideD00 = std::stoi(argv[11]);
StrideD01 = std::stoi(argv[12]);
StrideB1 = std::stoi(argv[13]);
StrideD1 = std::stoi(argv[14]);
StrideE1 = std::stoi(argv[15]);
BatchStrideA0 = std::stoi(argv[16]);
BatchStrideB0 = std::stoi(argv[17]);
BatchStrideD00 = std::stoi(argv[18]);
BatchStrideD01 = std::stoi(argv[19]);
BatchStrideB1 = std::stoi(argv[20]);
BatchStrideD1 = std::stoi(argv[21]);
BatchStrideE1 = std::stoi(argv[22]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 8: M, N, K, O, Batch\n");
printf(
"arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
"BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
exit(0);
}
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
}
else
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D00DataType> d00_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
Tensor<D01DataType> d01_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
<< " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
<< " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
case 2:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
d01_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 2>{StrideD00, StrideD01},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
sizeof(D1DataType) * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
if(do_verification)
{
using ReferenceGemm0Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
Acc0DataType,
Acc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
using ReferenceGemm1Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<Acc0DataType,
B1DataType,
Acc1DataType,
Acc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
// Output of Gemm0 is input A of Gemm1
Tensor<Acc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<Acc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<Acc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias+bias+relu
e0_g_m_n.ForEach([&](auto&, auto idx) {
cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
});
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// bias
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1;
}
return 0;
}
#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
int main(int argc, char* argv[]) { return run_example(argc, argv); }

View File

@@ -0,0 +1,350 @@
int run_example(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 256;
ck::index_t O = 512;
ck::index_t BatchCount = 4;
ck::index_t StrideA0 = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideD00 = -1;
ck::index_t StrideD01 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideD1 = -1;
ck::index_t StrideE1 = -1;
ck::index_t BatchStrideA0 = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideD00 = -1;
ck::index_t BatchStrideD01 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideD1 = -1;
ck::index_t BatchStrideE1 = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 23)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA0 = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideD00 = std::stoi(argv[11]);
StrideD01 = std::stoi(argv[12]);
StrideB1 = std::stoi(argv[13]);
StrideD1 = std::stoi(argv[14]);
StrideE1 = std::stoi(argv[15]);
BatchStrideA0 = std::stoi(argv[16]);
BatchStrideB0 = std::stoi(argv[17]);
BatchStrideD00 = std::stoi(argv[18]);
BatchStrideD01 = std::stoi(argv[19]);
BatchStrideB1 = std::stoi(argv[20]);
BatchStrideD1 = std::stoi(argv[21]);
BatchStrideE1 = std::stoi(argv[22]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 8: M, N, K, O, Batch\n");
printf(
"arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
"BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
exit(0);
}
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
}
else
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D00DataType> d00_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
Tensor<D01DataType> d01_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
<< " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
<< " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
case 2:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
d01_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 2>{StrideD00, StrideD01},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
sizeof(D1DataType) * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
if(do_verification)
{
using ReferenceGemm0Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
AccDataType,
AccDataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
using ReferenceGemm1Instance =
ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B1DataType,
AccDataType,
AccDataType,
PassThrough,
B1ElementOp,
PassThrough>;
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<AccDataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// bias+bias+relu
// Note that we also convert from AccDataType to A0DataType to match what the device
// operation does
e0_g_m_n.ForEach([&](auto&, auto idx) {
AccDataType out;
cde0_element_op(out, c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
});
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// bias
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
// NOTE: For float initialization (mode 2) verification currently fails due to inaccuracy.
// This seems to just be accumulating errors due to double gemm. It only seems to happen
// when using B1 tensor containing negative values, as this can get large values from gemm0
// back to zero again but reduce the tolerance allowed by the relative tolerance.
//
// There doesn't seem to be any bug with the implementation, just a difference in order of
// operations between CPU and GPU causing an accumulating error.
bool validation_result = ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
std::cout << "Validation result: " << (validation_result ? "SUCCESS" : "FAIL") << "."
<< std::endl;
return validation_result ? 0 : 1;
}
return 0;
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
// E = Relu(C + D0 + D1)
struct AddAddRelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
}
};
// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
__host__ __device__ void
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const ck::half_t x = c + d0 + d1;
ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
x);
}
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
__host__ __device__ void
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
}
};

View File

@@ -20,6 +20,7 @@
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
namespace tensor_operation {
@@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
Tuple<>{}, // p_d0s_grid
arg.p_b1_grid + b1_batch_offset,
Tuple<>{}, // p_d1s_grid
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.b1_grid_desc,
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
@@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// DataType Family
ADataType,
B0DataType,
Tuple<>, // Ds0DataType
AccDataType, // Acc0DataType
B1DataType,
Tuple<>, // Ds1DataType
AccDataType, // Acc1DataType
CShuffleDataType,
CDataType,
@@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
Tuple<>, // Ds0GridDesc
B1GridDesc,
Tuple<>, // Ds1GridDesc
CGridDesc_M_N,
// Tiling Family
MPerBlock,
@@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B0BlockTransferDstScalarPerVector_K1,
true,
B0BlockLdsAddExtraL,
1, // CDE0BlockTransferSrcScalarPerVector
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
@@ -369,8 +379,8 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
@@ -405,10 +415,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
@@ -500,7 +510,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{

View File

@@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
if(!ck::is_xdl_wmma_supported<A0DataType, B0DataType, Gemm0MPerXdl, Gemm0NPerXdl>())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes."
<< std::endl;
}
return false;
}
@@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "wrong! Unsupported tensor layout combination." << std::endl;
}
return false;
}

View File

@@ -101,6 +101,15 @@ struct GemmGemmPadder
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
}
// D0[M, N]
template <typename D0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
{
return PadTensorDescriptor(
d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
// B1[Gemm1N, Gemm1K] = B1[O, N]
template <typename B1Desc_NRaw_KRaw>
__host__ __device__ constexpr auto

View File

@@ -13,31 +13,35 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
// Gemm0: AccOp(A [M x K] x B0 [K x L], D0) = Acc [M x L]
// Gemm1: CDEOp1(Acc [M x L] x B1 [L x N], D1) = E [M x N]
template <typename ADataType,
typename B0DataType,
typename D0sDataType,
typename Acc0DataType,
typename B1DataType,
typename D1sDataType,
typename Acc1DataType,
typename CShuffleDataType,
typename CDataType,
typename E1DataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc,
typename B0GridDesc,
typename D0sGridDesc,
typename B1GridDesc,
typename CGridDesc_M_N,
typename D1sGridDesc,
typename E1GridDesc,
index_t MPerBlock,
index_t LPerBlock,
index_t KPerBlock,
@@ -69,6 +73,7 @@ template <typename ADataType,
index_t B0BlockTransferDstScalarPerVector_K1,
bool B0ThreadTransferSrcResetCoordinateAfterRun,
bool B0BlockLdsExtraL,
index_t CDE0BlockTransferSrcScalarPerVector,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
@@ -79,8 +84,8 @@ template <typename ADataType,
bool B1BlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
bool PadN,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
@@ -94,6 +99,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
@@ -105,9 +113,19 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
static constexpr auto BL0 = Number<L0PerBlock>{};
static constexpr auto BL1 = Number<L1Value>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WaveSize0 = BlockSize / (MWaves * LWaves);
static constexpr auto WaveSize1 = BlockSize / (MWaves * NWaves);
static constexpr auto WaveSize = WaveSize0;
static_assert(
WaveSize0 == 32 || WaveSize0 == 64,
"Misconfigured wave parameters: BlockSize / (MWaves * LWaves) != 32/64 threads per wave");
static_assert(
WaveSize1 == 32 || WaveSize1 == 64,
"Misconfigured wave parameters: BlockSize / (MWaves * NWaves) != 32/64 threads per wave");
static constexpr index_t KPerWmmaBlk =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
@@ -212,6 +230,52 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return b1_block_copy_step;
}
// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return static_cast<const D0iDataType*>(nullptr);
},
Number<NumD0Tensor>{});
}
// ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
static constexpr auto MakeD1sGridPointer()
{
return generate_tuple(
[&](auto i) {
using D1iDataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
return static_cast<const D1iDataType*>(nullptr);
},
Number<NumD1Tensor>{});
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, LWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / LPerWmma, LPerWmma))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
{
@@ -369,14 +433,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
return c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
@@ -432,12 +496,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
true>())>; // TransposeC (must be true to work), C' = B' x A'
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const D0sGridDesc& d0s_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const D1sGridDesc& d1s_grid_desc,
const E1GridDesc& c_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
// Print lambda with env check and printf() style formmating.
const char* curFunc = __func__;
@@ -482,6 +548,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return false;
}
bool d0s_desc_valid = true;
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
if(!(M == d0s_grid_desc[i].GetLength(I0) && L == d0s_grid_desc[i].GetLength(I1)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: M/L Length err, A_M/B0_L = %d, %d | D0s_M/N = %d, %d\n",
M,
L,
d0s_grid_desc[i].GetLength(I0),
d0s_grid_desc[i].GetLength(I1));
}
d0s_desc_valid = false;
}
});
bool d1s_desc_valid = true;
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
if(!(M == d1s_grid_desc[i].GetLength(I0) && N == d1s_grid_desc[i].GetLength(I1)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: M/N Length err, A_M/N = %d, %d | D1s_M/N = %d, %d\n",
M,
N,
d1s_grid_desc[i].GetLength(I0),
d1s_grid_desc[i].GetLength(I1));
}
d1s_desc_valid = false;
}
});
if(!(d0s_desc_valid && d1s_desc_valid))
{
return false;
}
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
@@ -513,11 +617,11 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
if(!block_2_etile_map.CheckValidity(c_grid_desc_m_n))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
print("GridwiseOp: invalid block_2_ctile_map\n");
print("GridwiseOp: invalid block_2_etile_map\n");
}
return false;
}
@@ -539,37 +643,94 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc& e_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
const D0GridDesc_M_N& d0_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto wmma =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / MPerBlock, MRepeat, MWaves, MPerWmma)),
make_unmerge_transform(make_tuple(N / LPerBlock,
LRepeat,
LWaves,
WaveSize / LPerWmma,
wmma.num_acc_vgprs_per_wave))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 4>{}, Sequence<1, 5, 6, 7, 8>{}));
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
const D0sGridDesc& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumD1Tensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
const E1GridDesc& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, E1GridDesc>(c_grid_desc_m_n);
}
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc{}))>;
using D0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs =
remove_cvref_t<
decltype(MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
D0sGridDesc{}))>;
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(E1GridDesc{}, 1, 1))>;
struct SharedMemTrait
{
@@ -600,45 +761,69 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
.GetElementSpaceSize();
};
using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());
template <bool HasMainKBlockLoop,
TailNumber TailNum,
typename Block2CTileMap = DefaultBlock2CTileMap>
typename Block2ETileMap = DefaultBlock2ETileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const D0sGridDesc& d0s_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const B0ElementwiseOperation& b0_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
const CDEElementwiseOperation& c_element_op,
const Block2ETileMap& block_2_etile_map)
{
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(d0s_grid_desc);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ return; }
// Store BlockId into SGPR
@@ -757,6 +942,72 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
// d0 matrix threadwise copy
constexpr auto d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
make_naive_tensor_descriptor_packed(make_tuple(
I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs));
auto d0s_thread_buf = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return StaticBuffer<
AddressSpaceEnum::Vgpr,
D0iDataType,
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});
const auto wave_id = GetGemm0WaveIdx(); // I0: MWaves, I1: LWaves, I2: WaveSize
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I0: WaveSize / LPerWmma, I1: LPerWmma
static_assert(CDE0BlockTransferSrcScalarPerVector <= laccvgprs,
"vector load must be not greater than n4");
static_assert(laccvgprs % CDE0BlockTransferSrcScalarPerVector == 0);
auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
D0iDataType,
D0iDataType,
decltype(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i]),
decltype(d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
Sequence<I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8, // NOTE: XDL has this exposed as CDE0BlockTransferSrcVectorDim.
// But as the grid descriptor is built internally, the parameter doesn't really make sense to configure per instance
CDE0BlockTransferSrcScalarPerVector,
1,
false>(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
wave_id[I0], // mwave
wave_m_n_id[I1], // mthreadpersubgroup
0, // nrepeat
wave_id[I1], // nwave
wave_m_n_id[I0], // nsubgroup
0)); // register number
},
Number<NumD0Tensor>{});
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
@@ -924,9 +1175,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale
// multiple d
if constexpr(NumD0Tensor)
{
constexpr auto d0s_thread_buf_size = d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize();
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
d0s_grid_buf[i],
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, d0s_thread_buf_size, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
Number<NumD0Tensor>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return acc0_thread_buf(i); },
Number<2>{});
unpack2(acc_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0));
});
}
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}
block_sync_lds();
@@ -995,15 +1281,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
}
} // end gemm1
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_mrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
@@ -1032,29 +1318,29 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
/*******************************************************************************/
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
// This API Provide All dimension (size) you need
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
constexpr auto c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
constexpr auto MWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
@@ -1097,10 +1383,10 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1125,36 +1411,68 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// tuple of reference to C/Ds tensor descriptors
const auto e1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(e1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(CGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray
// type
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{e1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
c_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1166,7 +1484,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
NAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_e1_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1174,37 +1492,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
c_shuffle_block_buf);
c1_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
cde1_shuffle_block_copy_lds_to_global.Run(
e1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);
// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
e1_d1s_desc_refs, i + I1, e1_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}

View File

@@ -219,6 +219,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// D0
//
static auto MakeD0GridDescriptorPair(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(d0_gs_ms_ns_lengths_vec,
d0_gs_ms_ns_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeD0GridDescriptor_G_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).first;
}
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return matrix_padder.PadD0Descriptor_M_N(
MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).second);
}
//
// B1
//

View File

@@ -14,7 +14,7 @@ namespace ck {
template <typename F, index_t... ids>
__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence<ids...>)
{
return make_tuple(f(Number<ids>{})...);
return ck::make_tuple(f(Number<ids>{})...);
}
template <typename F, index_t N>

View File

@@ -20,6 +20,52 @@ namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
@@ -59,7 +105,8 @@ void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_
PassThrough,
CDE1ElementOp>>>&
instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
@@ -113,22 +160,36 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
{
if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
{
#ifdef CK_USE_XDL
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
#endif
}
else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
{
#ifdef CK_USE_XDL
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
#endif
}
}
#endif
return op_ptrs;
}
};

View File

@@ -1,8 +1,11 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)

View File

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// No padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>,
// Fallback with padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, true, 1, 2, S<1, 16, 1, 2>, 1>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// No padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>,
// Fallback with padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, true, 1, 2, S<1, 16, 1, 2>, 1>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,387 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace ck {
namespace profiler {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementOp,
typename B0ElementOp,
typename CDE0ElementOp,
typename B1ElementOp,
typename CDE1ElementOp>
bool profile_batched_gemm_multiple_d_gemm_multiple_d_impl(
bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA0 = -1,
int StrideB0 = -1,
int StrideD0 = -1,
int StrideB1 = -1,
int StrideD1 = -1,
int StrideE1 = -1,
int BatchStrideA0 = -1,
int BatchStrideB0 = -1,
int BatchStrideD0 = -1,
int BatchStrideB1 = -1,
int BatchStrideD1 = -1,
int BatchStrideE1 = -1,
bool fail_if_no_supported_instance = false)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
// for reference
using RefAcc0DataType = float;
using RefAcc1DataType = float;
bool pass = true;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
}
else
{
return HostTensorDescriptor(
{batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D0DataType> d0_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(do_verification)
{
// Ref Gemm0
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
RefAcc0DataType,
RefAcc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
// Ref Gemm1
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B1DataType,
RefAcc1DataType,
RefAcc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// cde0_elementwise
// Note that we also convert from Acc0DataType to A0DataType to match what the device
// operation does
e0_g_m_n.ForEach([&](auto&, auto idx) {
RefAcc0DataType out;
cde0_element_op(out, c0_g_m_n(idx), d0_g_m_n(idx));
e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
});
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// cde1_elementwise
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int instances_supported = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 1>{StrideD0},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 1>{BatchStrideD0},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++instances_supported;
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N +
sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
if(instances_supported == 0)
{
if(do_log)
{
std::cout << "Warning! No supported instances found." << std::endl;
}
if(fail_if_no_supported_instance)
{
return false;
}
};
printf("\033[36mFound %d supported instances\033[0m\n", instances_supported);
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -275,6 +275,7 @@ add_subdirectory(batched_contraction)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_multiple_d_gemm_multiple_d)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(batched_gemm_b_scale)

View File

@@ -0,0 +1,12 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there's no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_batched_gemm_add_relu_gemm_add test_batched_gemm_add_relu_gemm_add.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_add_relu_gemm_add PRIVATE utility device_batched_gemm_add_relu_gemm_add_instance)
endif()
endif()

View File

@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_batched_gemm_multiple_d_gemm_multiple_d.hpp"
template <typename Tuple>
class TestBatchedGemmMultipleDGemmMultipleD
: public BaseTestBatchedGemmMultipleDGemmMultipleD<Tuple>
{
};
using A0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>,
std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmMultipleDGemmMultipleD, KernelTypes);
#include "test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc"

View File

@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct BaseTestBatchedGemmMultipleDGemmMultipleD : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using D0sDataType = std::tuple_element_t<2, Tuple>;
using B1DataType = std::tuple_element_t<3, Tuple>;
using D1sDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using ALayout = std::tuple_element_t<6, Tuple>;
using B0Layout = std::tuple_element_t<7, Tuple>;
using D0sLayout = std::tuple_element_t<8, Tuple>;
using B1Layout = std::tuple_element_t<9, Tuple>;
using D1sLayout = std::tuple_element_t<10, Tuple>;
using ELayout = std::tuple_element_t<11, Tuple>;
using A0ElementOp = std::tuple_element_t<12, Tuple>;
using B0ElementOp = std::tuple_element_t<13, Tuple>;
using CDE0ElementOp = std::tuple_element_t<14, Tuple>;
using B1ElementOp = std::tuple_element_t<15, Tuple>;
using CDE1ElementOp = std::tuple_element_t<16, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
};
bool bench_ = false;
bool verify_ = true;
void RunSingle(int M, int N, int K, int O, int BatchCount)
{
// WMMA instances are setup to support all the test cases
// XDL instances are not.
bool fail_if_no_supported_instances = ck::is_gfx11_supported() || ck::is_gfx12_supported();
bool pass =
ck::profiler::profile_batched_gemm_multiple_d_gemm_multiple_d_impl<ALayout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
ELayout,
ADataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
EDataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp>(
verify_,
1,
false,
bench_,
M,
N,
K,
O,
BatchCount,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
-1,
fail_if_no_supported_instances);
EXPECT_TRUE(pass);
}
void Run()
{
for(auto lengths : this->lengths_)
{
int M = lengths[0];
int N = lengths[1];
int K = lengths[2];
int O = lengths[3];
int BatchCount = lengths[4];
this->RunSingle(M, N, K, O, BatchCount);
}
}
};

View File

@@ -0,0 +1,88 @@
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test) { this->Run(); }
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 1},
{128, 128, 136, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddM)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddN)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddK)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 1},
{128, 128, 129, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddO)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
};
this->Run();
}