Add GemmAddSoftmaxGemm support for MSFT ORT (instances and client API) (#576)

* add instance for gemm bias softmax gemm

* add client example

* change CGridDesc_G_M_N to CGridDesc_G_M_O

* add gridwise

* change c grid name

* device add d0s data

* fix 08 client_example

* add example 47_fused_attention

* example output correct

* add d0 to example

* add d0 element op

* rechange instance code

* change Acc0ElementwiseOperation to C0DEElementwiseOperation

* change example name

* update instance for cdeelementwiseop

* add bhalf_t ScaleAdd

* add test

* not surport geem1 bias

* remove some ignore

* fix test bug

[ROCm/composable_kernel commit: 332ccc3367]
This commit is contained in:
ltqin
2023-02-09 04:34:45 +08:00
committed by GitHub
parent 09d4d0821a
commit 7cb7031d6a
18 changed files with 3787 additions and 106 deletions

View File

@@ -1,2 +1,5 @@
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)

View File

@@ -0,0 +1,226 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
constexpr static auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
using ADataType = ck::half_t;
using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using CDataType = ck::half_t;
using D0DataType = ck::half_t;
using AccDataType = float;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
int G0 = 48;
int G1 = 16;
int M = 1024;
int N = 1024;
int K = 64;
int O = 64;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
using DeviceOp =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D0DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D0DataType) * M * N) *
G0 * G1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}

View File

@@ -0,0 +1 @@
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)

View File

@@ -0,0 +1,408 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using C0DEElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using Acc0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
constexpr static auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using D0DataType = F16;
using Acc0BiasDataType = ck::Tuple<D0DataType>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
C0DEElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int G0 = 3;
int G1 = 2;
int M = 1024;
int N = 1024;
int K = 64;
int O = 64;
float alpha = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{
M * G1 * K, K, G1 * K, 1}; // A layout [G0, M, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{
N * G1 * K, K, G1 * K, 1}; // B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{
N * G1 * O, O, 1, G1 * O}; // B1 layout [G0, N, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{
M * G1 * O, O, G1 * O, 1}; // C layout [G0, M, G1, O]
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<D0DataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
}
DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto c0de_element_op = C0DEElementOp{alpha};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
auto argument = device_op.MakeArgument(
static_cast<const ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<const B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<const B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
a_element_op,
b0_element_op,
c0de_element_op,
b1_element_op,
c_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t BatchCount = G0 * G1;
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
sizeof(CDataType) * M * O + sizeof(D0DataType) * M * N) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
Tensor<D0DataType> d0_g_m_n({BatchCount, M, N});
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
acc0_g_m_n.ForEach([&](auto&, auto idx) {
c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx));
});
// masking
const auto mask = DeviceOpInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
return ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol)
? 0
: 1;
}
return 0;
}

View File

@@ -26,9 +26,9 @@ template <index_t NumDimG,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
@@ -58,9 +58,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
C1DEElementwiseOperation c1de_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

View File

@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -25,15 +25,17 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename D0sPointer,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename C1DEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
@@ -47,16 +49,19 @@ __global__ void
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
D0sPointer p_d0s_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const C0DEElementwiseOperation c0de_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const C1DEElementwiseOperation c1de_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
@@ -77,20 +82,28 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
});
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_d0s_grid,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
c0de_element_op,
b1_element_op,
c_element_op,
c1de_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
block_2_ctile_map,
c0_matrix_mask);
#else
@@ -100,13 +113,14 @@ __global__ void
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = c0de_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = c1de_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
@@ -126,15 +140,15 @@ template <index_t NumDimG,
typename BDataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename D0sDataType,
typename D1sDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename C1DEElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
@@ -192,23 +206,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
BDataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
@@ -261,14 +275,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<B1K1>{});
}
static auto MakeD0sGridDescriptor_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
static auto MakeD0sGridDescriptor_G_M_N(
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
{
return generate_tuple(
[&](auto i) {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
acc0_biases_gs_ms_ns_strides[i]);
},
Number<NumD0Tensor>{});
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using C1GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using C1GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {}));
using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
@@ -288,11 +328,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n)
const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
{
}
@@ -313,32 +355,42 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{
return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
D0sDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
C1DEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
C1GridDesc_M_N,
D0sGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -395,8 +447,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -405,44 +457,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
C1DEElementwiseOperation c1de_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_d0s_grid_{},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c0de_element_op_{c0de_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c1de_element_op_{c1de_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -456,27 +512,39 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
b1_grid_desc_g_n_k_,
c1_grid_desc_g_m_n_,
d0s_grid_desc_g_m_n_}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0 pointer
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
});
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
c1_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);
D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0s_grid_desc_m_n);
}
}
@@ -491,9 +559,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c1_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
}
// pointers
@@ -501,18 +569,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -520,9 +593,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
C0DEElementwiseOperation c0de_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
C1DEElementwiseOperation c1de_element_op_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
@@ -551,7 +624,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;
// Gemm0_K
const auto K =
@@ -564,15 +637,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
typename GridwiseGemm::D0sGridPointer,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
C1DEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
@@ -587,15 +662,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.p_b_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_d0s_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.c0de_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.c1de_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
@@ -644,9 +721,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
@@ -696,7 +773,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.c1_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
@@ -711,8 +788,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -721,17 +798,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
C1DEElementwiseOperation c1de_element_op)
{
return Argument{p_a,
p_b,
@@ -753,9 +830,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
c0de_element_op,
b1_element_op,
c_element_op};
c1de_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -767,8 +844,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const void* p_b,
const void* p_b1,
void* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -777,17 +854,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
C1DEElementwiseOperation c1de_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
@@ -809,9 +886,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
c0de_element_op,
b1_element_op,
c_element_op);
c1de_element_op);
}
// polymorphic

View File

@@ -49,6 +49,14 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -67,6 +75,30 @@ struct Add
};
};
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
template <>
__host__ __device__ void
operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
float scale_;
};
struct Subtract
{
template <typename T>

View File

@@ -91,6 +91,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;

View File

@@ -0,0 +1,190 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,7 @@
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
)

View File

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,395 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
namespace ck {
namespace profiler {
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasesDataType,
typename Acc1BiasesDataType,
tensor_operation::device::MaskingSpecialization MaskingSpec>
bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int G0,
int G1,
float alpha = -1.f)
{
using PassThrough = tensor_operation::element_wise::PassThrough;
using ScaleAdd = tensor_operation::element_wise::ScaleAdd;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using C0DEElementOp = ScaleAdd;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AccDataType = float;
using D0DataType = tuple_element_t<0, Acc0BiasesDataType>;
using tensor_operation::device::MaskingSpecialization;
// Ref Gemm0: various type in, fp32 out
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, various type out
using ReferenceSoftmaxInstance =
tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: various type in, various type out
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
bool pass = true;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
const int BatchCount = G0 * G1;
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<D0DataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
std::srand(1); // work around test flakiness
switch(init_method)
{
case 0: break;
case 1:
// Still unsure whether this kind of deterministic floating point accurary issue is expected
// or not. May want to try exact same approach as the GPU kernel in the host reference
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
// a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
// b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
if(alpha < 0)
{
alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
}
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto c0de_element_op = C0DEElementOp{alpha};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasesDataType,
ck::Tuple<>,
AElementOp,
B0ElementOp,
C0DEElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
Tensor<D0DataType> d0_g_m_n({BatchCount, M, N});
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
acc0_g_m_n.ForEach([&](auto&, auto idx) {
c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx));
});
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
std::array<void*, 1>{
d0_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // std::array<std::vector<ck::index_t>,
// 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
c0de_element_op,
b1_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D0DataType) * M * N) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> &&
std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> &&
std::is_same_v<CDataType, ck::bhalf_t> &&
std::is_same_v<D0DataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result,
c_gs_ms_os_host_result,
"Error: Incorrect results!",
rtol,
atol);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
c_gs_ms_os_device_result.mData,
",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -5,4 +5,11 @@ add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_ge
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)

View File

@@ -0,0 +1,182 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_bias_softmax_gemm_permute_util.hpp"
template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
: public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};
using I1_t = ck::Number<1>;
using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskOutUpperTriangle_t>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes);
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Test_BF16) { this->Run(); }
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 3, 2},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 2, 4},
{128, 128, 136, 128, 4, 2},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddM)
{
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 4, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 2, 3},
{128, 128, 129, 128, 2, 3},
};
this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
{256, 64, 160, 64, 1, 16},
{1024, 1024, 80, 80, 1, 16},
{1024, 64, 80, 64, 1, 16},
{4096, 4096, 40, 40, 1, 16},
{4096, 64, 40, 64, 1, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 48, 16},
{256, 256, 128, 128, 48, 16},
{512, 512, 64, 64, 48, 16},
{512, 512, 128, 128, 48, 16},
{1024, 1024, 64, 64, 48, 16},
{1024, 1024, 128, 128, 48, 16},
{2048, 2048, 64, 64, 48, 16},
{2048, 2048, 128, 128, 48, 16},
{4096, 4096, 64, 64, 48, 16},
{4096, 4096, 128, 128, 48, 16},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
int Q = 128; // do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, AdhocTest)
{
this->lengths_ = std::vector<std::vector<int>>{
{49, 49, 64, 64, 4, 6},
{64, 49, 64, 64, 4, 6},
{1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4, 6},
};
this->Run();
}

View File

@@ -0,0 +1,182 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
: public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};
using I1_t = ck::Number<1>;
using I2_t = ck::Number<2>;
using MaskDisabled_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskOutUpperTriangle_t>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes);
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); }
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 3, 2},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 2, 4},
{128, 128, 136, 128, 4, 2},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM)
{
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 4, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 2, 3},
{128, 128, 129, 128, 2, 3},
};
this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 2, 3},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
{256, 64, 160, 64, 1, 16},
{1024, 1024, 80, 80, 1, 16},
{1024, 64, 80, 64, 1, 16},
{4096, 4096, 40, 40, 1, 16},
{4096, 64, 40, 64, 1, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 48, 16},
{256, 256, 128, 128, 48, 16},
{512, 512, 64, 64, 48, 16},
{512, 512, 128, 128, 48, 16},
{1024, 1024, 64, 64, 48, 16},
{1024, 1024, 128, 128, 48, 16},
{2048, 2048, 64, 64, 48, 16},
{2048, 2048, 128, 128, 48, 16},
{4096, 4096, 64, 64, 48, 16},
{4096, 4096, 128, 128, 48, 16},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
int Q = 128; // do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{
this->lengths_ = std::vector<std::vector<int>>{
{49, 49, 64, 64, 4, 6},
{64, 49, 64, 64, 4, 6},
{1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4, 6},
};
this->Run();
}

View File

@@ -0,0 +1,380 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MaskingSpecialization;
using ck::tensor_operation::device::TensorSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
using NumDimGType = std::tuple_element_t<0, Tuple>;
using NumDimMType = std::tuple_element_t<1, Tuple>;
using NumDimNType = std::tuple_element_t<2, Tuple>;
using NumDimKType = std::tuple_element_t<3, Tuple>;
using NumDimOType = std::tuple_element_t<4, Tuple>;
using ADataType = std::tuple_element_t<5, Tuple>;
using B0DataType = std::tuple_element_t<6, Tuple>;
using B1DataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
using MaskingType = std::tuple_element_t<11, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 6, 4},
{256, 256, 128, 128, 4, 6},
{512, 512, 64, 64, 3, 2},
{512, 512, 128, 128, 2, 3},
{1024, 1024, 64, 64, 3, 1},
{1024, 1024, 128, 128, 1, 1},
};
bool bench_ = false;
bool verify_ = true;
void RunSingle(int M, int N, int K, int O, int G0, int G1)
{
bool pass =
ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl<NumDimGType::value,
NumDimMType::value,
NumDimNType::value,
NumDimKType::value,
NumDimOType::value,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
MaskingType::value>(
verify_, 2, false, bench_, M, N, K, O, G0, G1);
EXPECT_TRUE(pass);
}
void Run()
{
for(auto lengths : this->lengths_)
{
int M = lengths[0];
int N = lengths[1];
int K = lengths[2];
int O = lengths[3];
int G0 = lengths[4];
int G1 = lengths[5];
this->RunSingle(M, N, K, O, G0, G1);
}
}
};
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = float;
using CShuffleDataType = F16;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ScaleAdd;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<F16>,
ck::Tuple<>,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecialization::Default, // ATensorSpec
TensorSpecialization::Default, // B0TensorSpec
TensorSpecialization::Default, // B1TensorSpec
TensorSpecialization::Default, // CTensorSpec
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
bool IsSupported(int M, int N, int K, int O)
{
const int G0 = 1, G1 = 1;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
std::array<void*, 1>{nullptr}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
Acc0ElementOp{1.f}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using ADataType = BF16;
using B0DataType = BF16;
using B1DataType = BF16;
using AccDataType = float;
using CShuffleDataType = BF16;
using CDataType = BF16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ScaleAdd;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<BF16>,
ck::Tuple<>,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecialization::Default, // ATensorSpec
TensorSpecialization::Default, // B0TensorSpec
TensorSpecialization::Default, // B1TensorSpec
TensorSpecialization::Default, // CTensorSpec
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
bool IsSupported(int M, int N, int K, int O)
{
const int G0 = 1, G1 = 1;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
std::array<void*, 1>{nullptr}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
Acc0ElementOp{1.f}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};