mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add GemmAddSoftmaxGemm support for MSFT ORT (instances and client API) (#576)
* add instance for gemm bias softmax gemm
* add client example
* change CGridDesc_G_M_N to CGridDesc_G_M_O
* add gridwise
* change c grid name
* device add d0s data
* fix 08 client_example
* add example 47_fused_attention
* example output correct
* add d0 to example
* add d0 element op
* rechange instance code
* change Acc0ElementwiseOperation to C0DEElementwiseOperation
* change example name
* update instance for cdeelementwiseop
* add bhalf_t ScaleAdd
* add test
* gemm1 bias not supported
* remove some ignore
* fix test bug
[ROCm/composable_kernel commit: 332ccc3367]
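In effect, the fused op added by this commit computes, per batch index (g0, g1), a bias-added attention block (inferred from the client example and the new ScaleAdd element-wise op below; the kernel details live in the gridwise implementation):

    C = \mathrm{Softmax}\!\left(\alpha \, A B_0^{\top} + D_0\right) B_1, \qquad \alpha = 1/\sqrt{K} \text{ in the examples}

where A is [M, K], B0 is [N, K], D0 is [M, N], B1 is [N, O], C is [M, O], and the output is written back in permuted [G0, M, G1, O] storage.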
client_example/08_fused_attention/CMakeLists.txt (path inferred from the targets below)
@@ -1,2 +1,5 @@
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations)

add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)

client_example/08_fused_attention/fused_attention_bias.cpp (new file, 226 lines)
@@ -0,0 +1,226 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath> // sqrtf
#include <iostream>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using AElementOp    = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CElementOp    = ck::tensor_operation::element_wise::PassThrough;

constexpr static auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

using ADataType   = ck::half_t;
using B0DataType  = ck::half_t;
using B1DataType  = ck::half_t;
using CDataType   = ck::half_t;
using D0DataType  = ck::half_t;
using AccDataType = float;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main(int argc, char* argv[])
{
    int G0 = 48;
    int G1 = 16;
    int M  = 1024;
    int N  = 1024;
    int K  = 64;
    int O  = 64;

    // A layout [G0, M, G1, K]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

    // B0 layout [G0, N, G1, K]
    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

    // B1 layout [G0, N, G1, O]
    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

    // C layout [G0, M, G1, O]
    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

    // D0 layout [G0, M, G1, N]
    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

    SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
    SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);

    using DeviceOp =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                          1,
                                                                          1,
                                                                          1,
                                                                          1,
                                                                          ADataType,
                                                                          B0DataType,
                                                                          B1DataType,
                                                                          CDataType,
                                                                          ck::Tuple<D0DataType>,
                                                                          ck::Tuple<>,
                                                                          AElementOp,
                                                                          B0ElementOp,
                                                                          Acc0ElementOp,
                                                                          B1ElementOp,
                                                                          CElementOp,
                                                                          MaskingSpec>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device op instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr      = op_ptrs[i];
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b0_device_buf.GetDeviceBuffer(),
            b1_device_buf.GetDeviceBuffer(),
            c_device_buf.GetDeviceBuffer(),
            std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
            {},                                                    // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                       // acc1_biases_gs_ms_os_lengths
            {},                       // acc1_biases_gs_ms_os_strides
            AElementOp{},
            B0ElementOp{},
            Acc0ElementOp{1 / sqrtf(K)},
            B1ElementOp{},
            CElementOp{});

        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
                                     sizeof(D0DataType) * M * N) *
                                    G0 * G1;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // guard against the case where no instance supported the problem
    if(best_op_id < 0)
    {
        std::cout << "no instance supports this problem" << std::endl;
        return 0;
    }

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b0_device_buf.GetDeviceBuffer(),
            b1_device_buf.GetDeviceBuffer(),
            c_device_buf.GetDeviceBuffer(),
            std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
            {},                                                    // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                       // acc1_biases_gs_ms_os_lengths
            {},                       // acc1_biases_gs_ms_os_strides
            AElementOp{},
            B0ElementOp{},
            Acc0ElementOp{1 / sqrtf(K)},
            B1ElementOp{},
            CElementOp{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
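As a side note on the stride convention used above: a vector such as a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1} encodes [G0, M, G1, K] storage (K contiguous, then G1, then M, then G0). A minimal sketch of the implied linear offset; the helper below is hypothetical and only for illustration:

#include <cstddef>

// Linear offset of element (g0, g1, m, k) in an A tensor stored as [G0, M, G1, K].
inline std::size_t a_offset(int g0, int g1, int m, int k, int G1, int M, int K)
{
    return static_cast<std::size_t>(g0) * M * G1 * K // G0 stride = M * G1 * K
         + static_cast<std::size_t>(m) * G1 * K      // M stride  = G1 * K
         + static_cast<std::size_t>(g1) * K          // G1 stride = K
         + k;                                        // K stride  = 1
}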
example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt (new file, 1 line)
@@ -0,0 +1 @@
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)

example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp (new file, 408 lines)
@@ -0,0 +1,408 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using AElementOp    = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using C0DEElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using Acc0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CElementOp    = ck::tensor_operation::element_wise::PassThrough;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
constexpr static auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using F16 = ck::half_t;
using F32 = float;
using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
using D0DataType       = F16;
using Acc0BiasDataType = ck::Tuple<D0DataType>;
using Acc1BiasDataType = ck::Tuple<>;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using DeviceOpInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
        B0ElementOp,
        C0DEElementOp,
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128,         // MPerBlock
        128,         // NPerBlock
        32,          // KPerBlock
        64,          // Gemm1NPerBlock
        32,          // Gemm1KPerBlock
        8,           // AK1
        8,           // BK1
        2,           // B1K1
        32,          // MPerXDL
        32,          // NPerXDL
        1,           // MXdlPerWave
        4,           // NXdlPerWave
        2,           // Gemm1NXdlPerWave
        S<4, 64, 1>, // ABlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<4, 64, 1>, // BBlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<16, 16, 1>, // B1BlockTransfer
        S<0, 2, 1>,
        S<0, 2, 1>,
        1,
        4,
        2,
        false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                AccDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    int G0 = 3;
    int G1 = 2;
    int M  = 1024;
    int N  = 1024;
    int K  = 64;
    int O  = 64;
    float alpha = 1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        exit(0);
    }

    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{
        M * G1 * K, K, G1 * K, 1}; // A layout [G0, M, G1, K]

    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides{
        N * G1 * K, K, G1 * K, 1}; // B0 layout [G0, N, G1, K]

    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{
        N * G1 * O, O, 1, G1 * O}; // B1 layout [G0, N, G1, O]

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{
        M * G1 * O, O, G1 * O, 1}; // C layout [G0, M, G1, O]

    // D0 layout [G0, M, G1, N]
    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<D0DataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
    DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
    DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
    DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
    DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
    d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());

    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto c0de_element_op = C0DEElementOp{alpha};
    auto acc0_element_op = Acc0ElementOp{};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    auto argument = device_op.MakeArgument(
        static_cast<const ADataType*>(a_device_buf.GetDeviceBuffer()),
        static_cast<const B0DataType*>(b0_device_buf.GetDeviceBuffer()),
        static_cast<const B1DataType*>(b1_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
        std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
        {},                                                    // p_acc1_biases
        a_gs_ms_ks_lengths,
        a_gs_ms_ks_strides,
        b0_gs_ns_ks_lengths,
        b0_gs_ns_ks_strides,
        b1_gs_os_ns_lengths,
        b1_gs_os_ns_strides,
        c_gs_ms_os_lengths,
        c_gs_ms_os_strides,
        std::array<std::vector<ck::index_t>, 1>{
            d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
        std::array<std::vector<ck::index_t>, 1>{
            d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
        {},                       // acc1_biases_gs_ms_os_lengths
        {},                       // acc1_biases_gs_ms_os_strides
        a_element_op,
        b0_element_op,
        c0de_element_op,
        b1_element_op,
        c_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("wrong! this device_op instance does not support this problem");
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t BatchCount = G0 * G1;
    std::size_t flop       = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
         sizeof(CDataType) * M * O + sizeof(D0DataType) * M * N) *
        BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    if(do_verification)
    {
        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
        Tensor<D0DataType> d0_g_m_n({BatchCount, M, N});

        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });

        // gemm 0
        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // bias: acc0 = alpha * acc0 + d0
        acc0_g_m_n.ForEach([&](auto&, auto idx) {
            c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx));
        });

        // masking
        const auto mask = DeviceOpInstance::C0MatrixMask(N);
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(mask.IsMaskedElement(idx[1], idx[2]))
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });

        // softmax
        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        ref_softmax_invoker.Run(ref_softmax_argument);

        // gemm1
        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        // default absolute error and relative error is 0.001
        double rtol = 1e-3;
        double atol = 1e-3;

        return ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                    c_gs_ms_os_host_result.mData,
                                    "Error: Incorrect results!",
                                    rtol,
                                    atol)
                   ? 0
                   : 1;
    }

    return 0;
}
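For reference, the row-wise softmax the host verification applies over the N dimension reduces to the standard max-stabilized formulation; a minimal standalone sketch (the actual ReferenceSoftmax implementation may differ in detail):

#include <algorithm>
#include <cmath>
#include <vector>

// Row-wise softmax over a [rows, cols] buffer, as applied to acc0 ([M, N] per batch).
void softmax_rows(std::vector<float>& x, int rows, int cols)
{
    for(int r = 0; r < rows; ++r)
    {
        float* row    = x.data() + static_cast<std::size_t>(r) * cols;
        const float m = *std::max_element(row, row + cols); // subtract max for stability
        float sum     = 0.f;
        for(int c = 0; c < cols; ++c)
        {
            row[c] = std::exp(row[c] - m);
            sum += row[c];
        }
        for(int c = 0; c < cols; ++c)
            row[c] /= sum;
    }
}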
@@ -26,9 +26,9 @@ template <index_t NumDimG,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename C1DEElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
@@ -58,9 +58,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        B0ElementwiseOperation b0_element_op,
        Acc0ElementwiseOperation acc0_element_op,
        C0DEElementwiseOperation c0de_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) = 0;
        C1DEElementwiseOperation c1de_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -25,15 +25,17 @@ namespace device {
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename D0sPointer,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename C1DEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,
@@ -47,16 +49,19 @@ __global__ void
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            D0sPointer p_d0s_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const C0DEElementwiseOperation c0de_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const C1DEElementwiseOperation c1de_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c1_grid_desc_mblock_mperblock_nblock_nperblock,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
@@ -77,20 +82,28 @@ __global__ void
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
        p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
    });

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_b1_grid + b1_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_d0s_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  acc_element_op,
                                                  c0de_element_op,
                                                  b1_element_op,
                                                  c_element_op,
                                                  c1de_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  b1_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  c1_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                                  block_2_ctile_map,
                                                  c0_matrix_mask);
#else
@@ -100,13 +113,14 @@ __global__ void
    ignore = p_c_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
    ignore = c0de_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
    ignore = c1de_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
@@ -126,15 +140,15 @@ template <index_t NumDimG,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename D0sDataType,
          typename D1sDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename C1DEElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
@@ -192,23 +206,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                           BDataType,
                                           B1DataType,
                                           CDataType,
                                           Acc0BiasDataType,
                                           Acc1BiasDataType,
                                           D0sDataType,
                                           D1sDataType,
                                           AElementwiseOperation,
                                           BElementwiseOperation,
                                           AccElementwiseOperation,
                                           C0DEElementwiseOperation,
                                           B1ElementwiseOperation,
                                           CElementwiseOperation,
                                           C1DEElementwiseOperation,
                                           MaskingSpec>
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimension must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
    static constexpr index_t NumD0Tensor = D0sDataType::Size();
    static constexpr index_t NumD1Tensor = D1sDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
    static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");

#if 0
    // TODO ANT: use alias
@@ -261,14 +275,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            Number<B1K1>{});
    }

    static auto MakeD0sGridDescriptor_M_N(
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
                                                          acc0_biases_gs_ms_ns_strides[i]);
            },
            Number<NumD0Tensor>{});
    }

    static auto MakeD0sGridDescriptor_G_M_N(
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
                                                            acc0_biases_gs_ms_ns_strides[i]);
            },
            Number<NumD0Tensor>{});
    }

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using C1GridDesc_M_N       = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using C1GridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using D0sGridDesc_M_N      = decltype(MakeD0sGridDescriptor_M_N({}, {}));
    using D0sGridDesc_G_M_N    = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));

    constexpr static auto make_MaskOutPredicate()
    {
@@ -288,11 +328,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
                                     const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
                                     const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
              c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
              d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
        {
        }

@@ -313,32 +355,42 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
            return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        template <index_t I>
        __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                                Number<I> d0_idx) const
        {
            return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
        D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
    };

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
    using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        D0sDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
        C0DEElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        C1DEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        B1GridDesc_BK0_N_BK1,
        CGridDesc_M_N,
        C1GridDesc_M_N,
        D0sGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
@@ -395,8 +447,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const BDataType* p_b_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 const std::array<void*, NumAcc0Bias> p_acc0_biases,
                 const std::array<void*, NumAcc1Bias> p_acc1_biases,
                 const std::array<void*, NumD0Tensor> p_acc0_biases,
                 const std::array<void*, NumD1Tensor> p_acc1_biases,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -405,44 +457,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                 const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
                 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
                 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
                     acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
                     acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 C0DEElementwiseOperation c0de_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
                 C1DEElementwiseOperation c1de_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              p_d0s_grid_{},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                   c_gs_ms_gemm1ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_g_n_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                       c_gs_ms_gemm1ns_strides)},
              d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
                  acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
              c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
              d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              c0de_element_op_{c0de_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              c1de_element_op_{c1de_element_op},
              c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -456,27 +512,39 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                        b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{
                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
              batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
                                         b_grid_desc_g_n_k_,
                                         b1_grid_desc_g_n_k_,
                                         c1_grid_desc_g_m_n_,
                                         d0s_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ns_lengths;
            ignore = acc0_biases_gs_ms_ns_strides;
            ignore = acc1_biases_gs_ms_gemm1ns_lengths;
            ignore = acc1_biases_gs_ms_gemm1ns_strides;

            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
                using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
                // D0 pointer
                p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
            });

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
                                           c_grid_desc_m_n_,
                                           c1_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
                c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c1_grid_desc_m_n_);

                D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
                    acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
                d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
                    GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                        d0s_grid_desc_m_n);
            }
        }

@@ -491,9 +559,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c1_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
        }

        // pointers
@@ -501,18 +569,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        const BDataType* p_b_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;
        typename GridwiseGemm::D0sGridPointer p_d0s_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        C1GridDesc_M_N c1_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
        C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
        D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;

        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
            d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -520,9 +593,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        C0DEElementwiseOperation c0de_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
        C1DEElementwiseOperation c1de_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;
@@ -551,7 +624,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
                arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;

            // Gemm0_K
            const auto K =
@@ -564,15 +637,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                GridwiseGemm,
                ADataType, // TODO: distinguish A/B datatype
                CDataType,
                typename GridwiseGemm::D0sGridPointer,
                AElementwiseOperation,
                BElementwiseOperation,
                AccElementwiseOperation,
                C0DEElementwiseOperation,
                B1ElementwiseOperation,
                CElementwiseOperation,
                C1DEElementwiseOperation,
                DeviceOp::AGridDesc_AK0_M_AK1,
                DeviceOp::BGridDesc_BK0_N_BK1,
                DeviceOp::B1GridDesc_BK0_N_BK1,
                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                ComputeBasePtrOfStridedBatch,
                C0MatrixMask,
@@ -587,15 +662,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                    arg.p_b_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
                    arg.p_d0s_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.acc_element_op_,
                    arg.c0de_element_op_,
                    arg.b1_element_op_,
                    arg.c_element_op_,
                    arg.c1de_element_op_,
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
                    arg.compute_base_ptr_of_batch_,
@@ -644,9 +721,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // TODO ANT: Check if tensor specialization & strides mismatch

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

@@ -696,7 +773,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.c1_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

@@ -711,8 +788,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                             const BDataType* p_b,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             const std::array<void*, NumAcc0Bias> p_acc0_biases,
                             const std::array<void*, NumAcc1Bias> p_acc1_biases,
                             const std::array<void*, NumD0Tensor> p_acc0_biases,
                             const std::array<void*, NumD1Tensor> p_acc1_biases,
                             const std::vector<index_t>& a_gs_ms_ks_lengths,
                             const std::vector<index_t>& a_gs_ms_ks_strides,
                             const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -721,17 +798,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                             const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                             const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                             const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                             const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                             const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                             const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                             const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
                             const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
                             const std::array<std::vector<ck::index_t>, NumD1Tensor>
                                 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                             const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                             const std::array<std::vector<ck::index_t>, NumD1Tensor>
                                 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             C0DEElementwiseOperation c0de_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
                             C1DEElementwiseOperation c1de_element_op)
    {
        return Argument{p_a,
                        p_b,
@@ -753,9 +830,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        a_element_op,
                        b_element_op,
                        acc_element_op,
                        c0de_element_op,
                        b1_element_op,
                        c_element_op};
                        c1de_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -767,8 +844,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        const void* p_b,
                        const void* p_b1,
                        void* p_c,
                        const std::array<void*, NumAcc0Bias> p_acc0_biases,
                        const std::array<void*, NumAcc1Bias> p_acc1_biases,
                        const std::array<void*, NumD0Tensor> p_acc0_biases,
                        const std::array<void*, NumD1Tensor> p_acc1_biases,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -777,17 +854,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
                        const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
                        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                        const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
                        const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
                        const std::array<std::vector<ck::index_t>, NumD1Tensor>
                            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                        const std::array<std::vector<ck::index_t>, NumD1Tensor>
                            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        C0DEElementwiseOperation c0de_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) override
                        C1DEElementwiseOperation c1de_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -809,9 +886,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                          acc1_biases_gs_ms_gemm1ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
                                          c0de_element_op,
                                          b1_element_op,
                                          c_element_op);
                                          c1de_element_op);
    }

    // polymorphic

@@ -49,6 +49,14 @@ struct Add
        y = x0 + x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
    {
        const float x1_tmp = ck::type_convert<float>(x1);
        y = x0 + x1_tmp;
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -67,6 +75,30 @@ struct Add
    };
};

struct ScaleAdd
{
    __host__ __device__ ScaleAdd(float scale) : scale_(scale) {}

    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

    template <>
    __host__ __device__ void
    operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
    {
        y = scale_ * x0 + ck::type_convert<float>(x1);
    };

    template <>
    __host__ __device__ void
    operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
    {
        y = scale_ * x0 + ck::type_convert<float>(x1);
    };

    float scale_;
};
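A minimal host-side usage sketch of the new ScaleAdd op, mirroring the c0de_element_op call in the example's verification path (the demo function and its values are illustrative only):

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

// ScaleAdd computes y = scale * x0 + x1, converting x1 to float first.
int scale_add_demo()
{
    ck::tensor_operation::element_wise::ScaleAdd op{0.125f}; // e.g. 1/sqrt(K) with K = 64

    float y             = 0.f;
    const float x0      = 2.f;                                // gemm0 accumulator value
    const ck::half_t x1 = ck::type_convert<ck::half_t>(1.f);  // D0 bias element

    op(y, x0, x1); // selects the <float, float, half_t> specialization; y == 1.25f
    return y == 1.25f ? 0 : 1;
}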

struct Subtract
{
    template <typename T>

(File diff suppressed because it is too large)
@@ -91,6 +91,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;

template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;

@@ -0,0 +1,190 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

template <typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, ck::Tuple<>, PassThrough, PassThrough, ScaleAdd,
        PassThrough, PassThrough, MaskingSpec>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasDataType, ck::Tuple<>, PassThrough, PassThrough, ScaleAdd,
        PassThrough, PassThrough, MaskingSpec>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
                     Acc0BiasDataType::Size() == 1 &&
                     is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
        {
            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
            {
                add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
            {
                add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
        }
        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
                          Acc0BiasDataType::Size() == 1 &&
                          is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
        {
            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
            {
                add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
            {
                add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
        }
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
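
A sketch of how a client could ask this factory for the bias-enabled instances, mirroring the
profiler code later in this commit; the main() wrapper and the concrete template-argument
choices are assumptions for illustration:

#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"

int main()
{
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
    using ck::tensor_operation::device::MaskingSpecialization;

    // One D0 bias tensor (ck::Tuple<F16>), no Acc1 bias; ScaleAdd fuses alpha * acc + bias
    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
    return 0;
}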
@@ -1,5 +1,7 @@
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
    device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
)

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;

// c[g, m, o] = softmax(scale * a[g, m, k] * b0[g, n, k] + d0[g, m, n]) * b1[g, n, o]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskOutUpperTriangle>{});
}

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;

// c[g, m, o] = softmax(scale * a[g, m, k] * b0[g, n, k] + d0[g, m, n]) * b1[g, n, o]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskOutUpperTriangle>{});
}

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,395 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

namespace ck {
namespace profiler {

template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasesDataType,
          typename Acc1BiasesDataType,
          tensor_operation::device::MaskingSpecialization MaskingSpec>
bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
                                                         int init_method,
                                                         bool do_log,
                                                         bool time_kernel,
                                                         int M,
                                                         int N,
                                                         int K,
                                                         int O,
                                                         int G0,
                                                         int G1,
                                                         float alpha = -1.f)
{
    using PassThrough = tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = tensor_operation::element_wise::ScaleAdd;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using C0DEElementOp = ScaleAdd;
    using Acc0ElementOp = PassThrough;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    using AccDataType = float;
    using D0DataType  = tuple_element_t<0, Acc0BiasesDataType>;

    using tensor_operation::device::MaskingSpecialization;

    // Ref Gemm0: various type in, fp32 out
    using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<
        ADataType, B0DataType, AccDataType, AccDataType, AElementOp, B0ElementOp, Acc0ElementOp>;

    // Ref Softmax: fp32 in, various type out
    using ReferenceSoftmaxInstance =
        tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

    // Ref Gemm1: various type in, various type out
    using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<
        ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;

    bool pass = true;

    // A layout [G0, M, G1, K]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

    // B0 layout [G0, N, G1, K]
    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

    // B1 layout [G0, N, G1, O]
    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

    // C layout [G0, M, G1, O]
    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

    // D0 layout [G0, M, G1, N]
    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

    const int BatchCount = G0 * G1;

    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<D0DataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    std::srand(1); // work around test flakiness
    switch(init_method)
    {
    case 0: break;
    case 1:
        // Still unsure whether this kind of deterministic floating point accuracy issue is
        // expected or not. May want to try the exact same approach as the GPU kernel in the host
        // reference GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away.
        // Until then, shrink the input value range, as it is less likely to produce errors of
        // around ~1e-3.
        // a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        // b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        // b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) *
                           c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
    d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());

    if(alpha < 0)
    {
        alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim)
    }
    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto c0de_element_op = C0DEElementOp{alpha};
    auto acc0_element_op = Acc0ElementOp{};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType,
        Acc0BiasesDataType, ck::Tuple<>, AElementOp, B0ElementOp, C0DEElementOp,
        B1ElementOp, CElementOp, MaskingSpec>;

    // get device op instances
    const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    if(do_verification)
    {
        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
        Tensor<D0DataType> d0_g_m_n({BatchCount, M, N});

        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        acc0_g_m_n.ForEach([&](auto&, auto idx) {
            c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx));
        });
        // mask out upper triangle
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });

        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        ref_softmax_invoker.Run(ref_softmax_argument);

        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });
    }

    std::string best_op_name;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device op instances
    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
            static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
            static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
            static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
            std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
            {},                                                    // p_acc1_biases (empty)
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {}, // acc1_biases_gs_ms_os_lengths (empty)
            {}, // acc1_biases_gs_ms_os_strides (empty)
            a_element_op,
            b0_element_op,
            c0de_element_op,
            b1_element_op,
            c_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::string op_name = op_ptr->GetTypeString();

            float ave_time =
                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

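            // Cost model sketch (assumed convention: 2 FLOPs per multiply-accumulate):
            // Gemm0 contributes 2*M*N*K and Gemm1 contributes 2*M*N*O FLOPs per batch;
            // the softmax work is left out of the FLOP count below.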
            std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
                                     sizeof(D0DataType) * M * N) *
                                    BatchCount;

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }

            if(do_verification)
            {
                c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

                // default absolute error and relative error are 0.001
                double rtol = 1e-3;
                double atol = 1e-3;

                // when BF16 is used, relax absolute error and relative error to 0.01
                if(std::is_same_v<ADataType, ck::bhalf_t> &&
                   std::is_same_v<B0DataType, ck::bhalf_t> &&
                   std::is_same_v<B1DataType, ck::bhalf_t> &&
                   std::is_same_v<CDataType, ck::bhalf_t> &&
                   std::is_same_v<D0DataType, ck::bhalf_t>)
                {
                    rtol = 1e-2;
                    atol = 1e-2;
                }

                pass = pass & ck::utils::check_err(c_gs_ms_os_device_result,
                                                   c_gs_ms_os_host_result,
                                                   "Error: Incorrect results!",
                                                   rtol,
                                                   atol);

                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
                                          c_gs_ms_os_device_result.mData,
                                          ",")
                        << std::endl;
                }
            }
        }
        else
        {
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    return pass;
}

} // namespace profiler
} // namespace ck
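
A hypothetical driver for the profiler entry point above; the header name, shape values, and
flag settings are illustrative assumptions, not part of the commit:

#include "profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp" // assumed header name

int main()
{
    using F16 = ck::half_t;
    using ck::tensor_operation::device::MaskingSpecialization;

    // Template args: NumDimG/M/N/K/O, A/B0/B1/C types, Acc0/Acc1 bias tuples, masking spec.
    // Runtime args: do_verification, init_method, do_log, time_kernel, M, N, K, O, G0, G1;
    // leaving alpha at its default (-1) selects 1 / sqrt(K).
    const bool pass =
        ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl<
            2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
            MaskingSpecialization::MaskDisabled>(true, 1, false, true, 512, 512, 64, 64, 4, 6);

    return pass ? 0 : 1;
}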
@@ -5,4 +5,11 @@ add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_ge
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)

add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
@@ -0,0 +1,182 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_bias_softmax_gemm_permute_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
|
||||
: public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using I1_t = ck::Number<1>;
|
||||
using I2_t = ck::Number<2>;
|
||||
|
||||
using MaskDisabled_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
|
||||
using MaskOutUpperTriangle_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskDisabled_t>,
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, MaskOutUpperTriangle_t>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Test_BF16) { this->Run(); }
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 2, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 3, 2},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 2, 4},
|
||||
{128, 128, 136, 128, 4, 2},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 2, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 4, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 2, 3},
|
||||
{128, 128, 129, 128, 2, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 2, 3},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16_IrregularK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
|
||||
{256, 64, 160, 64, 1, 16},
|
||||
{1024, 1024, 80, 80, 1, 16},
|
||||
{1024, 64, 80, 64, 1, 16},
|
||||
{4096, 4096, 40, 40, 1, 16},
|
||||
{4096, 64, 40, 64, 1, 16}};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{256, 256, 64, 64, 48, 16},
|
||||
{256, 256, 128, 128, 48, 16},
|
||||
{512, 512, 64, 64, 48, 16},
|
||||
{512, 512, 128, 128, 48, 16},
|
||||
{1024, 1024, 64, 64, 48, 16},
|
||||
{1024, 1024, 128, 128, 48, 16},
|
||||
{2048, 2048, 64, 64, 48, 16},
|
||||
{2048, 2048, 128, 128, 48, 16},
|
||||
{4096, 4096, 64, 64, 48, 16},
|
||||
{4096, 4096, 128, 128, 48, 16},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
|
||||
{
|
||||
int P = 120; // requires padding
|
||||
int Q = 128; // do not require padding
|
||||
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
{
    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
    // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
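    // (this instance reads A/B with vector width 8 along K, so neither 129 nor 130 qualifies)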
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
    // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
    // clang-format on
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, AdhocTest)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {49, 49, 64, 64, 4, 6},
        {64, 49, 64, 64, 4, 6},
        {1020, 1020, 64, 128, 4, 6},
        {576, 576, 64, 64, 4, 6},
    };
    this->Run();
}
@@ -0,0 +1,182 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "test_batched_gemm_softmax_gemm_permute_util.hpp"

template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};

using I1_t = ck::Number<1>;
using I2_t = ck::Number<2>;

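// MaskOutUpperTriangle is the causal-attention variant: attention scores above
// the diagonal are masked out ahead of the softmax.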
using MaskDisabled_t =
    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;

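// Each KernelTypes entry is {NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
// ADataType, B0DataType, B1DataType, CDataType, Acc0Bias tuple, Acc1Bias tuple,
// MaskingSpecialization}, matching the tuple_element accesses in the fixture.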
// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskOutUpperTriangle_t>
    >;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes);

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); }

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 2, 3},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 3, 2},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 2, 4},
        {128, 128, 136, 128, 4, 2},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1, 3},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 2, 3},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 4, 3},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 2, 3},
        {128, 128, 129, 128, 2, 3},
    };
    this->Run();
}

// If the kernel's B1Layout is RowMajor, odd O sizes are expected to be unsupported
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 2, 3},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
{
    this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
                                                   {256, 64, 160, 64, 1, 16},
                                                   {1024, 1024, 80, 80, 1, 16},
                                                   {1024, 64, 80, 64, 1, 16},
                                                   {4096, 4096, 40, 40, 1, 16},
                                                   {4096, 64, 40, 64, 1, 16}};
    this->bench_  = true;
    this->verify_ = false;
    this->Run();
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {256, 256, 64, 64, 48, 16},
        {256, 256, 128, 128, 48, 16},
        {512, 512, 64, 64, 48, 16},
        {512, 512, 128, 128, 48, 16},
        {1024, 1024, 64, 64, 48, 16},
        {1024, 1024, 128, 128, 48, 16},
        {2048, 2048, 64, 64, 48, 16},
        {2048, 2048, 128, 128, 48, 16},
        {4096, 4096, 64, 64, 48, 16},
        {4096, 4096, 128, 128, 48, 16},
    };
    this->bench_  = true;
    this->verify_ = false;
    this->Run();
}

using ck::tensor_operation::device::GemmSpecialization;

TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
{
    int P = 120; // requires padding
    int Q = 128; // does not require padding

    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
    // clang-format on
}

TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
{
    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
    // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
    // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
    // clang-format on
}

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {49, 49, 64, 64, 4, 6},
        {64, 49, 64, 64, 4, 6},
        {1020, 1020, 64, 128, 4, 6},
        {576, 576, 64, 64, 4, 6},
    };
    this->Run();
}
@@ -0,0 +1,380 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp"

using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MaskingSpecialization;
using ck::tensor_operation::device::TensorSpecialization;

template <ck::index_t N>
using I = ck::Number<N>;

using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
    using NumDimGType      = std::tuple_element_t<0, Tuple>;
    using NumDimMType      = std::tuple_element_t<1, Tuple>;
    using NumDimNType      = std::tuple_element_t<2, Tuple>;
    using NumDimKType      = std::tuple_element_t<3, Tuple>;
    using NumDimOType      = std::tuple_element_t<4, Tuple>;
    using ADataType        = std::tuple_element_t<5, Tuple>;
    using B0DataType       = std::tuple_element_t<6, Tuple>;
    using B1DataType       = std::tuple_element_t<7, Tuple>;
    using CDataType        = std::tuple_element_t<8, Tuple>;
    using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
    using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
    using MaskingType      = std::tuple_element_t<11, Tuple>;

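    // each row of lengths_ is {M, N, K, O, G0, G1}; Run() unpacks it in that order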
    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 6, 4},
        {256, 256, 128, 128, 4, 6},
        {512, 512, 64, 64, 3, 2},
        {512, 512, 128, 128, 2, 3},
        {1024, 1024, 64, 64, 3, 1},
        {1024, 1024, 128, 128, 1, 1},
    };
    bool bench_  = false;
    bool verify_ = true;

    void RunSingle(int M, int N, int K, int O, int G0, int G1)
    {
        bool pass =
            ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl<NumDimGType::value,
                                                                              NumDimMType::value,
                                                                              NumDimNType::value,
                                                                              NumDimKType::value,
                                                                              NumDimOType::value,
                                                                              ADataType,
                                                                              B0DataType,
                                                                              B1DataType,
                                                                              CDataType,
                                                                              Acc0BiasDataType,
                                                                              Acc1BiasDataType,
                                                                              MaskingType::value>(
                verify_, 2, false, bench_, M, N, K, O, G0, G1);

        EXPECT_TRUE(pass);
    }

    void Run()
    {
        for(auto lengths : this->lengths_)
        {
            int M  = lengths[0];
            int N  = lengths[1];
            int K  = lengths[2];
            int O  = lengths[3];
            int G0 = lengths[4];
            int G1 = lengths[5];

            this->RunSingle(M, N, K, O, G0, G1);
        }
    }
};

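// Thin wrapper that instantiates the device op with a caller-chosen
// GemmSpecialization so the interface tests above can probe, per padding mode,
// what IsSupportedArgument accepts.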
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = F16;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = ScaleAdd;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
            2,
            1,
            1,
            1,
            1,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            ck::Tuple<F16>,
            ck::Tuple<>,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            TensorSpecialization::Default, // ATensorSpec
            TensorSpecialization::Default, // B0TensorSpec
            TensorSpecialization::Default, // B1TensorSpec
            TensorSpecialization::Default, // CTensorSpec
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
            MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle

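    // Dry-run support query: build an argument with null device pointers for the
    // requested problem size and ask the instance whether it could execute it;
    // nothing is launched, so the null pointers are never dereferenced.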
    bool IsSupported(int M, int N, int K, int O)
    {
        const int G0 = 1, G1 = 1;

        // A layout [G0, M, G1, K]
        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

        // B0 layout [G0, N, G1, K]
        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

        // B1 layout [G0, N, G1, O]
        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

        // C layout [G0, M, G1, O]
        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

        // D0 (bias) layout [G0, M, G1, N]
        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
        std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          std::array<void*, 1>{nullptr}, // p_acc0_biases
                                          {},                            // p_acc1_biases
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b0_gs_ns_ks_lengths,
                                          b0_gs_ns_ks_strides,
                                          b1_gs_os_ns_lengths,
                                          b1_gs_os_ns_strides,
                                          c_gs_ms_os_lengths,
                                          c_gs_ms_os_strides,
                                          std::array<std::vector<ck::index_t>, 1>{
                                              d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
                                          std::array<std::vector<ck::index_t>, 1>{
                                              d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
                                          {}, // acc1_biases_gs_ms_os_lengths
                                          {}, // acc1_biases_gs_ms_os_strides
                                          PassThrough{},      // a_element_op
                                          PassThrough{},      // b0_element_op
                                          Acc0ElementOp{1.f}, // acc0_element_op
                                          PassThrough{},      // b1_element_op
                                          PassThrough{});     // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};

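// BF16 twin of the FP16 wrapper above; only the data types differ.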
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    using ADataType        = BF16;
    using B0DataType       = BF16;
    using B1DataType       = BF16;
    using AccDataType      = float;
    using CShuffleDataType = BF16;
    using CDataType        = BF16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = ScaleAdd;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
            2,
            1,
            1,
            1,
            1,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            ck::Tuple<BF16>,
            ck::Tuple<>,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            TensorSpecialization::Default, // ATensorSpec
            TensorSpecialization::Default, // B0TensorSpec
            TensorSpecialization::Default, // B1TensorSpec
            TensorSpecialization::Default, // CTensorSpec
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
            MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle

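    // Same null-pointer dry run as in the FP16 wrapper: construct the argument,
    // then query IsSupportedArgument without launching anything.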
    bool IsSupported(int M, int N, int K, int O)
    {
        const int G0 = 1, G1 = 1;

        // A layout [G0, M, G1, K]
        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

        // B0 layout [G0, N, G1, K]
        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

        // B1 layout [G0, N, G1, O]
        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

        // C layout [G0, M, G1, O]
        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

        // D0 (bias) layout [G0, M, G1, N]
        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
        std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          std::array<void*, 1>{nullptr}, // p_acc0_biases
                                          {},                            // p_acc1_biases
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b0_gs_ns_ks_lengths,
                                          b0_gs_ns_ks_strides,
                                          b1_gs_os_ns_lengths,
                                          b1_gs_os_ns_strides,
                                          c_gs_ms_os_lengths,
                                          c_gs_ms_os_strides,
                                          std::array<std::vector<ck::index_t>, 1>{
                                              d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
                                          std::array<std::vector<ck::index_t>, 1>{
                                              d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
                                          {}, // acc1_biases_gs_ms_os_lengths
                                          {}, // acc1_biases_gs_ms_os_strides
                                          PassThrough{},      // a_element_op
                                          PassThrough{},      // b0_element_op
                                          Acc0ElementOp{1.f}, // acc0_element_op
                                          PassThrough{},      // b1_element_op
                                          PassThrough{});     // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};