Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 02:02:46 +00:00
Input/output permutation for fused attention (#460)
* re-enable masked attention instance now that CI has been upgraded
* re-enable instances that previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
add non-permuted problem statement
evaluate whether overhead comes from the permutation or from the extra kernel argument
* interface for bias addition without implementing it
* test and profiler running
* tidy
* mask type is now determined by an enum class (see the sketch after the commit message)
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment with merging dims for attention w/ permute; shows perf parity with attention w/o permute
* add tensor specialization to template args
since the 'packed' tensor specialization shows perf parity when permutation isn't needed
remove redundant template args
comment on 'packed' tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor
Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: de37550f72]
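The masking change in this commit replaces a bare boolean template parameter on the device op with a MaskingSpecialization enum class; in the per-file diffs below this shows up as "true>; // MaskOutUpperTriangle" turning into "MaskingSpec>; // MaskingSpecialization". The following is a minimal standalone sketch of that pattern, not the commit's actual masking header: only the enumerator names MaskDisabled / MaskOutUpperTriangle and the "mask where m < n" rule are taken from the diff; the surrounding scaffolding is illustrative.

    // Sketch: an enum class is self-documenting at the instantiation site and
    // extensible to further mask types, unlike the old bool parameter.
    #include <cstdio>

    enum class MaskingSpecialization
    {
        MaskDisabled,        // no masking (was: false)
        MaskOutUpperTriangle // causal-style mask (was: true)
    };

    template <MaskingSpecialization MaskingSpec>
    struct C0MatrixMask
    {
        // Element (m, n) of the Gemm0 result is masked out when it lies
        // strictly above the diagonal and upper-triangle masking is enabled.
        static bool IsMaskedElement(int m, int n)
        {
            return MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && m < n;
        }
    };

    int main()
    {
        using Mask = C0MatrixMask<MaskingSpecialization::MaskOutUpperTriangle>;
        std::printf("%d %d\n", Mask::IsMaskedElement(0, 1), Mask::IsMaskedElement(1, 0)); // prints: 1 0
    }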
@@ -2,9 +2,11 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)

add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
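All of the example binaries registered above share one command-line convention (see the argument parsing in the sources below). A typical invocation, with the binary name taken from the targets above and the numeric values from the examples' defaults (the exact binary path depends on the build tree; this is illustrative):

    ./example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16                           # defaults
    ./example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 1 1 1                     # verify, integer init, time kernel
    ./example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 1 1 1 512 512 64 128 7 13 1.0
                                                               #     M   N   K  O  G0 G1 alpha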
@@ -33,9 +33,6 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
@@ -44,13 +41,14 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;

using CPermuteNumDims_G_M_O =
    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
@@ -59,17 +57,27 @@ using B1ElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        ALayout,
        B0Layout,
        B1Layout,
        CPermuteNumDims_G_M_O,
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
@@ -78,6 +86,10 @@ using DeviceGemmInstance =
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128, // MPerBlock
@@ -118,7 +130,7 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        true>;          // MaskOutUpperTriangle
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
@@ -142,268 +154,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 512;
    ck::index_t N = 512;
    ck::index_t K = 64;
    ck::index_t O = 128;
    ck::index_t StrideA       = -1;
    ck::index_t StrideB0      = -1;
    ck::index_t StrideB1      = -1;
    ck::index_t BatchStrideA  = -1;
    ck::index_t BatchStrideB0 = -1;
    ck::index_t BatchStrideB1 = -1;
    float alpha = 1;

    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
    ck::index_t G0 = 7;
    ck::index_t G1 = 13;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 11: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        exit(0);
    }

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

    const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;

    StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;

    const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;

    BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;

    const int BatchCount = G0 * G1;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
        }
    };

    // C_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<ADataType> a_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<CDataType> c_gs_ms_os_host_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
    Tensor<CDataType> c_gs_ms_os_device_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        break;
    case 2:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    default:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
                                    c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          c_gs_ms_os_lengths,
                          c_gs_ms_os_strides,
                          StrideA,
                          StrideB0,
                          StrideB1,
                          BatchStrideA,
                          BatchStrideB0,
                          BatchStrideB1,
                          a_element_op,
                          b0_element_op,
                          acc0_element_op,
                          b1_element_op,
                          c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop      = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                            BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    if(do_verification)
    {
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

        Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

        Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
                                              std::vector<int>{M * O, O, 1});

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        // gemm 0
        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // mask out upper triangle
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(idx[1] < idx[2])
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });

        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        // softmax
        ref_softmax_invoker.Run(ref_softmax_argument);

        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        // gemm1
        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
                   ? 0
                   : 1;
    }

    return 0;
}
int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -33,9 +33,6 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
@@ -44,13 +41,14 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;

using CPermuteNumDims_G_M_O =
    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
@@ -59,17 +57,27 @@ using B1ElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        ALayout,
        B0Layout,
        B1Layout,
        CPermuteNumDims_G_M_O,
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
@@ -78,6 +86,10 @@ using DeviceGemmInstance =
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128, // MPerBlock
@@ -118,7 +130,7 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        false>;         // MaskOutUpperTriangle
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
@@ -142,258 +154,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 120;
    ck::index_t N = 1000;
    ck::index_t K = 64;
    ck::index_t O = 128;
    ck::index_t StrideA       = -1;
    ck::index_t StrideB0      = -1;
    ck::index_t StrideB1      = -1;
    ck::index_t BatchStrideA  = -1;
    ck::index_t BatchStrideB0 = -1;
    ck::index_t BatchStrideB1 = -1;
    float alpha = 1;

    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
    ck::index_t G0 = 7;
    ck::index_t G1 = 13;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 11: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        exit(0);
    }

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

    const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;

    StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;

    const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;

    BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;

    const int BatchCount = G0 * G1;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
        }
    };

    // C_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<ADataType> a_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<CDataType> c_gs_ms_os_host_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
    Tensor<CDataType> c_gs_ms_os_device_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        break;
    case 2:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    default:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
                                    c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          c_gs_ms_os_lengths,
                          c_gs_ms_os_strides,
                          StrideA,
                          StrideB0,
                          StrideB1,
                          BatchStrideA,
                          BatchStrideB0,
                          BatchStrideB1,
                          a_element_op,
                          b0_element_op,
                          acc0_element_op,
                          b1_element_op,
                          c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop      = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                            BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    if(do_verification)
    {
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

        Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));

        Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
                                              std::vector<int>{M * O, O, 1});

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        ref_softmax_invoker.Run(ref_softmax_argument);

        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
                   ? 0
                   : 1;
    }

    return 0;
}
int main(int argc, char* argv[]) { return run(argc, argv); }
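Both batched examples above express the output permutation purely through strides: C keeps logical shape [G0, G1, M, O], but strides {M * G1 * O, O, G1 * O, 1} store it in [G0, M, G1, O] memory order, so no separate permute kernel is needed. A quick standalone sanity check of that stride choice (not code from the commit; the sizes mirror the defaults above):

    #include <cstdio>

    int main()
    {
        const long G1 = 13, M = 512, O = 128;
        // Strides for logical dims (g0, g1, m, o), as in c_gs_ms_os_strides.
        const long s[4] = {M * G1 * O, O, G1 * O, 1};

        // Linear offset of logical element (g0, g1, m, o).
        auto offset = [&](long g0, long g1, long m, long o) {
            return g0 * s[0] + g1 * s[1] + m * s[2] + o * s[3];
        };

        // o is contiguous, g1 steps by O, m steps by G1*O: exactly the
        // [G0, M, G1, O] layout the device kernel writes directly.
        std::printf("%ld %ld %ld\n",
                    offset(0, 0, 0, 1),  // 1
                    offset(0, 1, 0, 0),  // 128  (= O)
                    offset(0, 0, 1, 0)); // 1664 (= G1 * O)
    }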
@@ -0,0 +1,159 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
                                                                  |-----------------|
                                                                         Gemm0
                                                          |-------------------------------------|
                                                                         Gemm1
*/

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
        B0ElementOp,
        Acc0ElementOp,
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128,         // MPerBlock
        128,         // NPerBlock
        32,          // KPerBlock
        64,          // Gemm1NPerBlock
        32,          // Gemm1KPerBlock
        8,           // AK1
        8,           // BK1
        2,           // B1K1
        32,          // MPerXDL
        32,          // NPerXDL
        1,           // MXdlPerWave
        4,           // NXdlPerWave
        2,           // Gemm1NXdlPerWave
        S<4, 64, 1>, // ABlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<4, 64, 1>, // BBlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<16, 16, 1>, // B1BlockTransfer
        S<0, 2, 1>,
        S<0, 2, 1>,
        1,
        4,
        2,
        false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                AccDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;

#include "run_grouped_gemm_scale_softmax_gemm_permute.inc"

int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -33,9 +33,6 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F16;
@@ -44,13 +41,14 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;

using CPermuteNumDims_G_M_O =
    S<1, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_M_O
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
@@ -59,17 +57,27 @@ using B1ElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        ALayout,
        B0Layout,
        B1Layout,
        CPermuteNumDims_G_M_O,
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
@@ -78,6 +86,10 @@ using DeviceGemmInstance =
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128, // MPerBlock
@@ -118,7 +130,7 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        false>;
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
@@ -142,303 +154,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
        B1ElementOp,
        CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
#include "run_grouped_gemm_scale_softmax_gemm_permute.inc"

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        exit(0);
    }

    float alpha = 1; // scaling after 1st gemm

    std::size_t group_count = 13;

    // Problem descs
    std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
    std::vector<const void*> p_a;
    std::vector<const void*> p_b0;
    std::vector<const void*> p_b1;
    std::vector<void*> p_c;

    for(std::size_t i = 0; i < group_count; i++)
    {
        int M     = 128 * (rand() % 8 + 1);
        int N     = 128 * (rand() % 8 + 1);
        int K     = 40;
        int O     = 40 * (rand() % 2 + 1);
        int Batch = rand() % 8 + 1;

        const int StrideA  = ck::is_same_v<ALayout, Row> ? K : M;
        const int StrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
        const int StrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;

        const int BatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
        const int BatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
        const int BatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;

        std::vector<ck::index_t> c_gs_ms_os_lengths{Batch, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides{O, Batch * O, 1};

        problem_descs.push_back({M,
                                 N,
                                 K,
                                 O,
                                 Batch,
                                 StrideA,
                                 StrideB0,
                                 StrideB1,
                                 BatchStrideA,
                                 BatchStrideB0,
                                 BatchStrideB1,
                                 c_gs_ms_os_lengths,
                                 c_gs_ms_os_strides});
    }

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
        }
    };

    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<B0DataType>> b0_tensors;
    std::vector<Tensor<B1DataType>> b1_tensors;
    std::vector<Tensor<CDataType>> c_tensors;

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;

    std::vector<DeviceMemPtr> a_tensors_device;
    std::vector<DeviceMemPtr> b0_tensors_device;
    std::vector<DeviceMemPtr> b1_tensors_device;
    std::vector<DeviceMemPtr> c_tensors_device;

    std::size_t flop = 0, num_byte = 0;

    std::cout << "group count " << group_count << ". printing first 4 groups\n";
    for(std::size_t i = 0; i < group_count; i++)
    {
        const auto& M                  = problem_descs[i].M;
        const auto& N                  = problem_descs[i].N;
        const auto& K                  = problem_descs[i].K;
        const auto& O                  = problem_descs[i].O;
        const auto& Batch              = problem_descs[i].Batch;
        const auto& StrideA            = problem_descs[i].StrideA;
        const auto& StrideB0           = problem_descs[i].StrideB0;
        const auto& StrideB1           = problem_descs[i].StrideB1;
        const auto& BatchStrideA       = problem_descs[i].BatchStrideA;
        const auto& BatchStrideB0      = problem_descs[i].BatchStrideB0;
        const auto& BatchStrideB1      = problem_descs[i].BatchStrideB1;
        const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
        const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;

        // C_m_o = A_m_k * B0_k_n * B1_n_o
        Tensor<ADataType> a_g_m_k(
            f_host_tensor_descriptor(Batch, M, K, StrideA, BatchStrideA, ALayout{}));
        Tensor<B0DataType> b0_g_k_n(
            f_host_tensor_descriptor(Batch, K, N, StrideB0, BatchStrideB0, B0Layout{}));
        Tensor<B1DataType> b1_g_n_o(
            f_host_tensor_descriptor(Batch, N, O, StrideB1, BatchStrideB1, B1Layout{}));
        Tensor<CDataType> c_gs_ms_os_device_result(
            std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
            std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
        num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                    Batch;

        if(i < 4)
        {
            std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.mDesc << ", "
                      << "b0_g_k_n[" << i << "]: " << b0_g_k_n.mDesc << ", "
                      << "b1_g_n_o[" << i << "]: " << b1_g_n_o.mDesc << ", "
                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
        }

        switch(init_method)
        {
        case 0: break;
        case 1:
            a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
            b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
            b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
            break;
        case 2:
            a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
            b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
            b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
            break;
        case 3:
            a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
            b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
            b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
            break;
        default:
            a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
            b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
            b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        }

        a_tensors.push_back(a_g_m_k);
        b0_tensors.push_back(b0_g_k_n);
        b1_tensors.push_back(b1_g_n_o);
        c_tensors.push_back(c_gs_ms_os_device_result);

        a_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()));
        b0_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()));
        b1_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));

        a_tensors_device[i]->ToDevice(a_g_m_k.mData.data());
        b0_tensors_device[i]->ToDevice(b0_g_k_n.mData.data());
        b1_tensors_device[i]->ToDevice(b1_g_n_o.mData.data());

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
        p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a,
                                      p_b0,
                                      p_b1,
                                      p_c,
                                      problem_descs,
                                      a_element_op,
                                      b0_element_op,
                                      acc0_element_op,
                                      b1_element_op,
                                      c_element_op);

    // specify workspace for problem_desc
    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));

    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    bool pass = true;
    if(do_verification)
    {
        for(std::size_t i = 0; i < group_count; i++)
        {
            const auto& M                  = problem_descs[i].M;
            const auto& N                  = problem_descs[i].N;
            const auto& O                  = problem_descs[i].O;
            const auto& Batch              = problem_descs[i].Batch;
            const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
            const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;

            const auto& a_g_m_k            = a_tensors[i];
            const auto& b0_g_k_n           = b0_tensors[i];
            const auto& b1_g_n_o           = b1_tensors[i];
            auto& c_gs_ms_os_device_result = c_tensors[i];
            auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];

            Tensor<CDataType> c_gs_ms_os_host_result(
                std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
                std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

            // Output of Gemm0 is input A of Gemm1
            Tensor<AccDataType> acc0_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));

            Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));

            Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{Batch, M, O},
                                                  std::vector<int>{M * O, O, 1});

            auto ref_gemm0          = ReferenceGemm0Instance{};
            auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
            auto ref_gemm0_argument = ref_gemm0.MakeArgument(
                a_g_m_k, b0_g_k_n, acc0_m_n, a_element_op, b0_element_op, acc0_element_op);

            ref_gemm0_invoker.Run(ref_gemm0_argument);

            auto ref_softmax          = ReferenceSoftmaxInstance{};
            auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
            auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_m_n, a1_g_m_n, 1, 0, {2});

            ref_softmax_invoker.Run(ref_softmax_argument);

            auto ref_gemm1          = ReferenceGemm1Instance{};
            auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
            auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
                                                             b1_g_n_o,
                                                             c_g_m_o_host_result,
                                                             PassThrough{},
                                                             b1_element_op,
                                                             c_element_op);

            ref_gemm1_invoker.Run(ref_gemm1_argument);

            // Note: in this example, we merely permute the dimensions by changing underlying
            // strides so we simply access data as-is
            c_gs_ms_os_host_result.ForEach(
                [&](auto& self, auto idx) { self(idx) = c_g_m_o_host_result(idx); });

            bool pass_ =
                ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
            pass &= pass_;
        }
    }

    return pass ? 0 : 1;
}
int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -0,0 +1,262 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape for A/B0/B1/C
|
||||
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
|
||||
ck::index_t M = 120;
|
||||
ck::index_t N = 1000;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
|
||||
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
|
||||
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
|
||||
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
|
||||
ck::index_t G0 = 7;
|
||||
ck::index_t G1 = 13;
|
||||
|
||||
float alpha = 1;
|
||||
|
||||
bool input_permute = false;
|
||||
bool output_permute = true;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
G0 = std::stoi(argv[8]);
|
||||
G1 = std::stoi(argv[9]);
|
||||
|
||||
alpha = std::stof(argv[10]);
|
||||
|
||||
input_permute = std::stoi(argv[11]);
|
||||
output_permute = std::stoi(argv[12]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 11: M, N, K, O, G0, G1\n");
|
||||
printf("arg10: scale (alpha)\n");
|
||||
printf("arg11 to 12: input / output permute\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides =
|
||||
input_permute
|
||||
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
|
||||
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
|
||||
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_strides =
|
||||
input_permute
|
||||
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
|
||||
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
|
||||
|
||||
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
|
||||
std::vector<ck::index_t> b1_gs_os_ns_strides =
|
||||
input_permute
|
||||
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
|
||||
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
|
||||
|
||||
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
|
||||
std::vector<ck::index_t> c_gs_ms_os_strides =
|
||||
output_permute
|
||||
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
|
||||
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) *
                           c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    // TODO ANT: replace array with vector?
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
        static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
        static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
        {}, // std::array<void*, 1> p_acc0_biases;
        {}, // std::array<void*, 1> p_acc1_biases;
        a_gs_ms_ks_lengths,
        a_gs_ms_ks_strides,
        b0_gs_ns_ks_lengths,
        b0_gs_ns_ks_strides,
        b1_gs_os_ns_lengths,
        b1_gs_os_ns_strides,
        c_gs_ms_os_lengths,
        c_gs_ms_os_strides,
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
        a_element_op,
        b0_element_op,
        acc0_element_op,
        b1_element_op,
        c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    ck::index_t BatchCount = G0 * G1;

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                            BatchCount;
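
The FLOP and byte counts above follow directly from the two chained GEMMs (softmax work is not counted): gemm0 is an M x K by K x N product and gemm1 is an M x N by N x O product, each costing two flops per multiply-accumulate. A small illustrative helper (not part of the example):

    #include <cstddef>

    // flops = (2*M*N*K + 2*M*N*O) * batch; with ave_time in ms,
    // flops / 1e9 / ave_time yields TFLOPS (the 1e3 ms->s factor folds into 1e12).
    constexpr std::size_t attention_gemm_flops(
        std::size_t M, std::size_t N, std::size_t K, std::size_t O, std::size_t batch)
    {
        return (2 * M * N * K + 2 * M * N * O) * batch;
    }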

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    if(do_verification)
    {
        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1

        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });

        // gemm 0
        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // masking
        const auto mask = DeviceGemmInstance::C0MatrixMask(N);
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(mask.IsMaskedElement(idx[1], idx[2]))
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });
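
The host-side mask mirrors the device kernel's causal-attention behavior: any score whose column index n exceeds its row index m (or lies past the valid N extent) becomes -inf, so it contributes exp(-inf) == 0 to the softmax that follows. The effect on a 3x3 score tile, as an illustration:

    // [ s00  -inf  -inf ]   row m keeps only columns n <= m
    // [ s10  s11   -inf ]
    // [ s20  s21   s22  ]
    //
    // Standalone restatement of the predicate (assumed equivalent to
    // IsMaskedElement above, written out for clarity):
    inline bool is_masked(int m, int n, int n_raw) { return n > m || n >= n_raw; }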

        // softmax
        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        ref_softmax_invoker.Run(ref_softmax_argument);

        // gemm 1
        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
                   ? 0
                   : 1;
    }

    return 0;
}

@@ -0,0 +1,319 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    bool input_permute  = false;
    bool output_permute = true;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 6)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        input_permute  = std::stoi(argv[4]);
        output_permute = std::stoi(argv[5]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 and arg5: input permute / output permute (0=no, 1=yes)\n");
        exit(0);
    }

    float alpha = 1; // scaling after 1st gemm

    std::size_t group_count = 7;

    // Problem descs
    std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
    std::vector<const void*> p_a;
    std::vector<const void*> p_b0;
    std::vector<const void*> p_b1;
    std::vector<void*> p_c;
    std::vector<std::vector<int>> g0_g1_m_n_k_o;

    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<B0DataType>> b0_tensors;
    std::vector<Tensor<B1DataType>> b1_tensors;
    std::vector<Tensor<CDataType>> c_tensors;

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
    std::vector<DeviceMemPtr> a_tensors_device;
    std::vector<DeviceMemPtr> b0_tensors_device;
    std::vector<DeviceMemPtr> b1_tensors_device;
    std::vector<DeviceMemPtr> c_tensors_device;

    std::size_t flop = 0, num_byte = 0;

    std::cout << "group count " << group_count << ". printing first 4 groups\n";
    for(std::size_t i = 0; i < group_count; i++)
    {
        int M  = 128 * (rand() % 8 + 1);
        int N  = 128 * (rand() % 8 + 1);
        int K  = 40;
        int O  = 40 * (rand() % 2 + 1);
        int G0 = rand() % 3 + 1;
        int G1 = rand() % 5 + 1;

        g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});

        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides =
            input_permute
                ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
                : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides =
            input_permute
                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides =
            input_permute
                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides =
            output_permute
                ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

        problem_descs.push_back({a_gs_ms_ks_lengths,
                                 a_gs_ms_ks_strides,
                                 b0_gs_ns_ks_lengths,
                                 b0_gs_ns_ks_strides,
                                 b1_gs_os_ns_lengths,
                                 b1_gs_os_ns_strides,
                                 c_gs_ms_os_lengths,
                                 c_gs_ms_os_strides,
                                 {},   // acc0_biases_gs_ms_ns_lengths
                                 {},   // acc0_biases_gs_ms_ns_strides
                                 {},   // acc1_biases_gs_ms_os_lengths
                                 {}}); // acc1_biases_gs_ms_os_strides

        // C_m_o = A_m_k * B0_k_n * B1_n_o
        Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
        Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
        Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
        Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

        int Batch = G0 * G1;
        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
        num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                    Batch;

        if(i < 4)
        {
            std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", "
                      << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                      << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
        }

        switch(init_method)
        {
        case 0: break;
        case 1:
            a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
            b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
            b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
            break;
        case 2:
            a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
            b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
            b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
            break;
        case 3:
            a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
            b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
            b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
            break;
        default:
            a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
            b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
            b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        }

        a_tensors.push_back(a_gs_ms_ks);
        b0_tensors.push_back(b0_gs_ns_ks);
        b1_tensors.push_back(b1_gs_os_ns);
        c_tensors.push_back(c_gs_ms_os_device_result);

        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));
        b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()));
        b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));

        a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data());
        b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data());
        b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data());

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
        p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a,
                                      p_b0,
                                      p_b1,
                                      p_c,
                                      {}, // p_acc0_biases
                                      {}, // p_acc1_biases
                                      problem_descs,
                                      a_element_op,
                                      b0_element_op,
                                      acc0_element_op,
                                      b1_element_op,
                                      c_element_op);

    // specify workspace for problem_desc
    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));

    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
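
Unlike the batched operator, the grouped operator reads its per-group kernel arguments from device memory, so the caller must allocate a workspace sized by the operator and attach it to the argument before launch. The general pattern, restated as a sketch (names match the calls above; the comment on what Run does is an assumption, not confirmed by this diff):

    DeviceMem workspace(gemm.GetWorkSpaceSize(&argument)); // bytes for the per-group arg array
    gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
    // invoker.Run(...) is then expected to copy the group descriptors into this
    // buffer before dispatching the grid.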

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    bool pass = true;
    if(do_verification)
    {
        for(std::size_t i = 0; i < group_count; i++)
        {
            const int& G0 = g0_g1_m_n_k_o[i][0];
            const int& G1 = g0_g1_m_n_k_o[i][1];
            const int& M  = g0_g1_m_n_k_o[i][2];
            const int& N  = g0_g1_m_n_k_o[i][3];
            const int& K  = g0_g1_m_n_k_o[i][4];
            const int& O  = g0_g1_m_n_k_o[i][5];

            const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
            const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;

            const auto& a_gs_ms_ks         = a_tensors[i];
            const auto& b0_gs_ns_ks        = b0_tensors[i];
            const auto& b1_gs_os_ns        = b1_tensors[i];
            auto& c_gs_ms_os_device_result = c_tensors[i];
            auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];

            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

            Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
            Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
            Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
            Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N});        // scratch object after gemm0
            Tensor<ADataType> a1_g_m_n({G0 * G1, M, N});            // scratch object after softmax
            Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
            Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

            // permute
            a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
            });
            b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
            });
            b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
            });

            // gemm 0
            auto ref_gemm0          = ReferenceGemm0Instance{};
            auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
            auto ref_gemm0_argument = ref_gemm0.MakeArgument(
                a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

            ref_gemm0_invoker.Run(ref_gemm0_argument);

            // masking
            const auto mask = DeviceGemmInstance::C0MatrixMask(N);
            acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                if(mask.IsMaskedElement(idx[1], idx[2]))
                    self(idx) = -ck::NumericLimits<float>::Infinity();
            });

            // softmax
            auto ref_softmax          = ReferenceSoftmaxInstance{};
            auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
            auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

            ref_softmax_invoker.Run(ref_softmax_argument);

            // gemm 1
            auto ref_gemm1          = ReferenceGemm1Instance{};
            auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
            auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
                                                             b1_g_n_o,
                                                             c_g_m_o_host_result,
                                                             PassThrough{},
                                                             b1_element_op,
                                                             c_element_op);

            ref_gemm1_invoker.Run(ref_gemm1_argument);

            // permute
            c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
                const size_t& g0 = idx[0];
                const size_t& g1 = idx[1];

                const size_t g = g0 * G1 + g1;

                self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
            });

            bool pass_ =
                ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
            pass &= pass_;
        }
    }

    return pass ? 0 : 1;
}

@@ -159,6 +159,11 @@
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0

// workaround: disable broken fused attention kernel instance that does not pass validation
// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling
// enabled
#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1

namespace ck {

enum struct InMemoryDataOperationEnum

@@ -14,7 +14,8 @@ namespace ck {

template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess> // # of scalars per access in each dimension
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>   // snake (boustrophedon) traversal order on/off
struct SpaceFillingCurve
{
    static constexpr index_t nDim = TensorLengths::Size();
@@ -136,9 +137,10 @@ struct SpaceFillingCurve
        Index ordered_idx;

        static_for<0, nDim, 1>{}([&](auto idim) {
            ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim]
                                                    : ordered_access_lengths[idim] - 1 -
                                                          ordered_access_idx[idim];
            ordered_idx(idim) =
                !SnakeCurved || forward_sweep[idim]
                    ? ordered_access_idx[idim]
                    : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
        });

        return container_reorder_given_old2new(ordered_idx, dim_access_order) *
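
The new SnakeCurved parameter selects between boustrophedon ("snake") traversal and plain forward sweeps; with SnakeCurved == false every dimension is walked forward regardless of forward_sweep. Visualized for a 2 x 4 access grid (illustrative sketch, not actual SpaceFillingCurve output):

    // SnakeCurved == true        SnakeCurved == false
    //   0 1 2 3                    0 1 2 3
    //   7 6 5 4                    4 5 6 7
    //
    // Snake order keeps the step 3 -> 4 adjacent in the tile instead of jumping
    // back to the start of the next row, which is why it is the default.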

@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);

        return make_tuple(Number<m0>{},
                          Number<n0>{},
                          waveId_m,
                          waveId_n,
                          blk_idx[I0],
                          blk_idx[I1],
                          blk_idx[I2],
                          blk_idx[I3]);
    }

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);

        return make_tuple(
            m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
    }

    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),

@@ -24,7 +24,8 @@ template <typename ALayout,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
          typename CElementwiseOperation,
          bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>

@@ -7,49 +7,60 @@
#include <vector>

#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
          typename CElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        std::vector<index_t> c_gs_ms_os_lengths,
                        std::vector<index_t> c_gs_ms_os_strides,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        Acc0ElementwiseOperation acc0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;
    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b0,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths,       // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides,       // c_gs_ms_os_strides
        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
        const std::array<std::vector<index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        B0ElementwiseOperation b0_element_op,
        Acc0ElementwiseOperation acc0_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

@@ -7,46 +7,50 @@
#include <vector>

#include "device_base.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
          typename CElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
    struct ProblemDesc
    {
        // Overall problem shape
        index_t M;
        index_t N;
        index_t K;
        index_t O;
        index_t Batch;
        std::vector<index_t> a_gs_ms_ks_lengths;
        std::vector<index_t> a_gs_ms_ks_strides;

        // Stride for A/B0/B1; layout determined by template args
        index_t StrideA;
        index_t StrideB0;
        index_t StrideB1;
        index_t BatchStrideA;
        index_t BatchStrideB0;
        index_t BatchStrideB1;
        std::vector<index_t> b0_gs_ns_ks_lengths;
        std::vector<index_t> b0_gs_ns_ks_strides;

        std::vector<index_t> b1_gs_os_ns_lengths;
        std::vector<index_t> b1_gs_os_ns_strides;

        // Lengths and strides for output C
        std::vector<index_t> c_gs_ms_os_lengths;
        std::vector<index_t> c_gs_ms_os_strides;

        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths;
        std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides;

        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths;
        std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides;
    };
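
With this change every per-group shape is carried implicitly by the length/stride vectors (M, N, K, O and the batch count are recovered from them), so groups may differ in both shape and layout. A sketch of how a caller might fill one descriptor, with illustrative packed-layout values:

    // ProblemDesc desc;
    // desc.a_gs_ms_ks_lengths  = {G0, G1, M, K};
    // desc.a_gs_ms_ks_strides  = {G1 * M * K, M * K, K, 1}; // packed [G0, G1, M, K]
    // desc.b0_gs_ns_ks_lengths = {G0, G1, N, K};
    // desc.b0_gs_ns_ks_strides = {G1 * N * K, N * K, K, 1};
    // desc.b1_gs_os_ns_lengths = {G0, G1, O, N};
    // desc.b1_gs_os_ns_strides = {G1 * N * O, N * O, 1, O};
    // desc.c_gs_ms_os_lengths  = {G0, G1, M, O};
    // desc.c_gs_ms_os_strides  = {G1 * M * O, M * O, O, 1};
    // desc.acc0_biases_gs_ms_ns_lengths = {}; // sizes must equal NumAcc0Bias/NumAcc1Bias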

    virtual std::unique_ptr<BaseArgument>
@@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
                        std::vector<const void*> p_b0_vec,
                        std::vector<const void*> p_b1_vec,
                        std::vector<void*> p_c_vec,
                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                        std::vector<ProblemDesc> problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,

@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

@@ -54,9 +55,8 @@ __global__ void
    index_t right    = group_count;
    index_t group_id = index_t((left + right) / 2);

    while((!(block_id >= arg_ptr[group_id].block_start_ &&
             block_id < arg_ptr[group_id].block_end_)) &&
          left <= right)
    while(
        (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
    {
        if(block_id < arg_ptr[group_id].block_start_)
        {
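
The loop above is a plain binary search: the per-group [block_start_, block_end_) ranges are sorted and disjoint, so each workgroup locates the group owning its block_id in O(log(group_count)) steps. A host-side restatement (illustrative, not the kernel code; assumes block_id falls in some range):

    #include <utility>
    #include <vector>

    // Find the group whose half-open block range contains block_id.
    inline int find_group(const std::vector<std::pair<int, int>>& ranges, int block_id)
    {
        int left  = 0;
        int right = static_cast<int>(ranges.size());
        int g     = (left + right) / 2;
        while(!(block_id >= ranges[g].first && block_id < ranges[g].second))
        {
            if(block_id < ranges[g].first)
                right = g; // owner is in the lower half
            else
                left = g + 1; // owner is in the upper half
            g = (left + right) / 2;
        }
        return g;
    }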

@@ -114,14 +114,17 @@ __global__ void
// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
          typename BLayout, // B0Layout
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO, // NumDimGemm1N
          typename ADataType,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
@@ -130,6 +133,10 @@ template <typename ALayout,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization B1Spec,
          TensorSpecialization CSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
@@ -170,297 +177,152 @@ template <typename ALayout,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          bool MaskOutUpperTriangle,
          MaskingSpecialization MaskingSpec,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
    : public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
                                                 BLayout,
                                                 B1Layout,
                                                 CPermuteNumDims_G_M_Gemm1N,
    : public DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
                                                 NumDimM,
                                                 NumDimN,
                                                 NumDimK,
                                                 NumDimO,
                                                 ADataType,
                                                 BDataType,
                                                 B1DataType,
                                                 CDataType,
                                                 Acc0BiasDataType,
                                                 Acc1BiasDataType,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 AccElementwiseOperation,
                                                 B1ElementwiseOperation,
                                                 CElementwiseOperation>
                                                 CElementwiseOperation,
                                                 MaskingSpec>
{
    using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
    using ProblemDesc =
        typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
                                                     BLayout,
                                                     B1Layout,
                                                     CPermuteNumDims_G_M_Gemm1N,
                                                     ADataType,
                                                     BDataType,
                                                     B1DataType,
                                                     CDataType,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     AccElementwiseOperation,
                                                     B1ElementwiseOperation,
                                                     CElementwiseOperation>::ProblemDesc;
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimensions must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");

#if 0
    // TODO ANT: use alias
    static constexpr index_t NumDimGemm0M = NumDimM;
    static constexpr index_t NumDimGemm0N = NumDimN;
    static constexpr index_t NumDimGemm0K = NumDimK;
    static constexpr index_t NumDimGemm1M = NumDimM;
    static constexpr index_t NumDimGemm1N = NumDimO;
    static constexpr index_t NumDimGemm1K = NumDimN;
#endif

    using DeviceOp    = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
    using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute<NumDimG,
                                                                     NumDimM,
                                                                     NumDimN,
                                                                     NumDimK,
                                                                     NumDimO,
                                                                     ADataType,
                                                                     BDataType,
                                                                     B1DataType,
                                                                     CDataType,
                                                                     Acc0BiasDataType,
                                                                     Acc1BiasDataType,
                                                                     AElementwiseOperation,
                                                                     BElementwiseOperation,
                                                                     AccElementwiseOperation,
                                                                     B1ElementwiseOperation,
                                                                     CElementwiseOperation,
                                                                     MaskingSpec>::ProblemDesc;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto matrix_padder =
        GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
        GemmSpec,
        ASpec,
        BSpec,
        B1Spec,
        CSpec>;

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
            Number<AK1>{});
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                              const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);

        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
            Number<BK1>{});
    }

    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto
    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
    {
        const auto b1_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);

        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
            Number<B1K1>{});
    }

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_ns_strides_vec)
    {
        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N

        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto c_ms_ns_lengths = to_tuple(
            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);

        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
        const auto c_grid_desc_ms_ns =
            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
            c_grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
    {
        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N

        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto c_gs_ms_ns_lengths =
            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_gs_ms_ns_strides =
            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);

        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto c_grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 *
        // N2 * ...]
        const auto c_grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // this desc is only for calculating batch offset so no padding needed
        return c_grid_desc_g_mraw_nraw;
    }

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));

    // to track the points which need to be set to -inf on C0
    // Note: no need to reset M padding value, because they will not be stored out.
    struct C0MatrixMask
    {
        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
        {
            return n >= NRaw_;
        }

        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
        {
            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
        }

    private:
        // index_t MRaw_;
        index_t NRaw_;
    };
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

    constexpr static auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return MaskDisabledPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
        {
            return MaskOutUpperTrianglePredicate{};
        }
    }

    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
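
Masking behavior is now resolved at compile time by selecting a predicate type instead of branching inside the mask object. A minimal sketch of the idea (the real predicates live in masking_specialization.hpp; these bodies are assumptions for illustration):

    // Predicate chosen when MaskingSpec == MaskDisabled: nothing is masked.
    struct MaskDisabledPredicateSketch
    {
        __host__ __device__ bool operator()(int /*m*/, int /*n*/) const { return false; }
    };

    // Predicate chosen when MaskingSpec == MaskOutUpperTriangle: causal mask.
    struct MaskOutUpperTrianglePredicateSketch
    {
        __host__ __device__ bool operator()(int m, int n) const { return n > m; }
    };

    // C0MatrixMask_impl<Predicate> then layers the n >= NRaw out-of-bound check
    // on top of whichever predicate was selected.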

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB,
                                     index_t BatchStrideB1,
                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideB1_(BatchStrideB1),
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
@@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
        }

    private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideB1_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
    };
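
Computing batch base pointers from the G_M_K/G_N_K/G_M_N descriptors removes the assumption of a single flat batch stride: once the G dims may have arbitrary strides, the offset of batch g must be evaluated through the descriptor. Spelled out for two G dims (illustrative derivation):

    // With G-dim lengths {G0, G1} and strides {s_g0, s_g1}, batch g = g0 * G1 + g1:
    //
    //   offset(g) = g0 * s_g0 + g1 * s_g1
    //             = (g / G1) * s_g0 + (g % G1) * s_g1
    //
    // which is what CalculateOffset(make_multi_index(g_idx, 0, 0)) evaluates after
    // the G dims are merged. The old g * BatchStride form is only the special case
    // s_g0 == G1 * s_g1.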

@@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        matrix_padder.PadN,
        MaskOutUpperTriangle>;
        Transform::matrix_padder.PadN,
        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

    using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;

@@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

    struct GroupDeviceArg
    {
        // problem definition
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
// lengths for the last dimensions of overall problem for sanity check of vector load/store
|
||||
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
|
||||
|
||||
// Strides for the last dimensions of C for sanity check of vector load/store
|
||||
index_t c_extent_lowest_;
|
||||
index_t c_stride_lowest_;
|
||||
// strides for the last dimensions of each tensor for sanity check of vector load/store
|
||||
std::vector<index_t> a_mz_kz_strides_;
|
||||
std::vector<index_t> b_nz_kz_strides_;
|
||||
std::vector<index_t> b1_nz_kz_strides_;
|
||||
std::vector<index_t> c_mz_gemm1nz_strides_;
|
||||
|
||||
// for gridwise gemm check
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
};
|
||||
|
||||
@@ -591,6 +453,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
std::vector<const void*> p_b_vec,
|
||||
std::vector<const void*> p_b1_vec,
|
||||
std::vector<void*> p_c_vec,
|
||||
std::vector<std::vector<const void*>> p_acc0_biases_vec,
|
||||
std::vector<std::vector<const void*>> p_acc1_biases_vec,
|
||||
std::vector<ProblemDesc> problem_desc_vec,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
b1_element_op_{b1_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
// TODO ANT: implement bias addition
|
||||
group_count_ = problem_desc_vec.size();
|
||||
|
||||
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
|
||||
@@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
|
||||
}
|
||||
|
||||
if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
|
||||
}
|
||||
|
||||
grid_size_ = 0;
|
||||
|
||||
for(std::size_t i = 0; i < group_count_; i++)
|
||||
@@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
|
||||
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
|
||||
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
|
||||
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
|
||||
const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
|
||||
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
|
||||
const auto& problem_desc = problem_desc_vec[i];
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
|
||||
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
|
||||
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
|
||||
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
|
||||
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
|
||||
|
||||
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
|
||||
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
|
||||
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
|
||||
problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
|
||||
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
|
||||
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
|
||||
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
|
||||
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
|
||||
const index_t BlockStart = grid_size_;
|
||||
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
|
||||
const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
|
||||
problem_desc_vec[i].Batch;
|
||||
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
|
||||
const index_t grid_size_grp =
|
||||
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
|
||||
const index_t BlockEnd = grid_size_ + grid_size_grp;
|
||||
|
||||
// batch stride
|
||||
// TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
|
||||
const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
|
||||
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
|
||||
const auto compute_base_ptr_of_batch =
|
||||
ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
|
||||
problem_desc_vec[i].BatchStrideB0,
|
||||
problem_desc_vec[i].BatchStrideB1,
|
||||
c_grid_desc_g_m_n);
|
||||
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
|
||||
a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n);
|
||||
|
||||
// C0 mask
|
||||
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
|
||||
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
|
||||
// so on
|
||||
if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
|
||||
problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
|
||||
problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
|
||||
problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! number of biases in function argument does not "
|
||||
"match that in template argument");
|
||||
}
|
||||
|
||||
group_kernel_args_.push_back({p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
@@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                   BlockStart,
                                   BlockEnd});

            group_device_args_.push_back({problem_desc_vec[i].M,
                                          problem_desc_vec[i].N,
                                          problem_desc_vec[i].K,
                                          problem_desc_vec[i].O,
                                          problem_desc_vec[i].c_gs_ms_os_lengths.back(),
                                          problem_desc_vec[i].c_gs_ms_os_strides.back(),
                                          c_grid_desc_m_n});
            group_device_args_.push_back(
                {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                  problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                  problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
                  problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
                 {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                  problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
                 {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
                  problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
                 {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
                  problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
                 {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                  problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                 c_grid_desc_m_n});
        }
    }

@@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
            return false;
        }

        // TODO ANT: Check if tensor specialization & strides mismatch

        bool all_has_main_k_block_loop  = true;
        bool some_has_main_k_block_loop = false;

@@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle

            // Note: we need raw lengths since threadwise copy cannot handle vector load when
            // part of vector is out of bounds
            const auto MRaw      = device_arg.M;
            const auto NRaw      = device_arg.N;
            const auto KRaw      = device_arg.K;
            const auto Gemm1NRaw = device_arg.O;
            const auto MzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
            const auto NzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
            const auto KzRaw      = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
            const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

            // Check scalar per vector requirement
            const auto a_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
            const auto b_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
            const auto b1_extent_lowest =
                is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
            const auto c_extent_lowest = device_arg.c_extent_lowest_;
            const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
            const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
            const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
            const auto c_extent_lowest  = Gemm1NzRaw;

            if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
                 b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
@@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                return false;
            }

            // Check vector store requirement; assumes last dimension in N to be contiguous
            if(device_arg.c_stride_lowest_ != 1)
            // Check vector load/store requirement
            const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
                                             ? device_arg.a_mz_kz_strides_[1]
                                             : device_arg.a_mz_kz_strides_[0];
            const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
                                             ? device_arg.b_nz_kz_strides_[1]
                                             : device_arg.b_nz_kz_strides_[0];
            const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                              ? device_arg.b1_nz_kz_strides_[1]
                                              : device_arg.b1_nz_kz_strides_[0];
            const auto c_stride_lowest =
                device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
                                                     // contiguous

            if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
                 c_stride_lowest == 1))
            {
                return false;
            }
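
A note on the vector load/store checks above: a SrcScalarPerVector-wide memory access is only legal when the dimension selected by the corresponding *BlockTransferSrcVectorDim has unit stride in memory, which is exactly what the *_stride_lowest values feed into. A standalone host-side sketch of the property being checked (float4_t, the function name, and all sizes are invented for illustration):

#include <cstring>
#include <vector>

struct float4_t { float v[4]; };

// A 4-wide "vector load" of consecutive K-elements from row-major A[M][K].
// This is a single contiguous memcpy only because the stride along K is 1 --
// the same unit-stride property the checks above guard against.
float4_t load_vec4(const float* a, int m, int k0, int stride_m)
{
    float4_t out;
    std::memcpy(out.v, a + m * stride_m + k0, sizeof(out));
    return out;
}

int main()
{
    std::vector<float> a(8 * 16, 1.f); // M = 8, K = 16, stride_m = 16
    float4_t r = load_vec4(a.data(), 2, 4, 16);
    return r.v[0] == 1.f ? 0 : 1;
}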
@@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                     std::vector<const void*> p_b_vec,
                     std::vector<const void*> p_b1_vec,
                     std::vector<void*> p_c_vec,
                     std::vector<std::vector<const void*>> p_acc0_biases_vec,
                     std::vector<std::vector<const void*>> p_acc1_biases_vec,
                     std::vector<ProblemDesc> problem_desc_vec,
                     AElementwiseOperation a_element_op,
                     BElementwiseOperation b_element_op,
@@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        p_b_vec,
                        p_b1_vec,
                        p_c_vec,
                        p_acc0_biases_vec,
                        p_acc1_biases_vec,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
@@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
                                                      std::vector<const void*> p_b_vec,
                                                      std::vector<const void*> p_b1_vec,
                                                      std::vector<void*> p_c_vec,
                                                      std::vector<ProblemDesc> problem_desc_vec,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      AccElementwiseOperation acc_element_op,
                                                      B1ElementwiseOperation b1_element_op,
                                                      CElementwiseOperation c_element_op) override
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*> p_a_vec,
                        std::vector<const void*> p_b_vec,
                        std::vector<const void*> p_b1_vec,
                        std::vector<void*> p_c_vec,
                        std::vector<std::vector<const void*>> p_acc0_biases_vec,
                        std::vector<std::vector<const void*>> p_acc1_biases_vec,
                        std::vector<ProblemDesc> problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(p_a_vec,
                                          p_b_vec,
                                          p_b1_vec,
                                          p_c_vec,
                                          p_acc0_biases_vec,
                                          p_acc1_biases_vec,
                                          problem_desc_vec,
                                          a_element_op,
                                          b_element_op,
@@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ">";
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">";
        // clang-format on

        return str.str();

@@ -130,8 +130,11 @@ namespace device {
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]

// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases; it
// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1
// NOTE: a TensorSpecialization::Packed tensor is "packed" in the sense that the inner dimensions
// of each dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
// ordered -- not in the sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be
// permuted while still being a contiguous, unpadded tensor. In other words, it merely degenerates
// into TensorSpecialization::Default with NumDimG/M/N/K = 1
//
// Detail- Packed tensor satisfies
//   stride_0 = 1
@@ -147,7 +150,7 @@ namespace device {
// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed.
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,

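To make the "packed within each dimension group" contract above concrete, here is a small standalone sketch (shapes invented for the example) that computes the strides such a tensor must have:

#include <array>

int main()
{
    // C[G0, G1, M0, M1, N0] with Gs = {2, 3}, Ms = {4, 5}, Ns = {8}
    std::array<int, 5> lengths{2, 3, 4, 5, 8};
    std::array<int, 5> strides{};
    int s = 1;
    for(int i = 4; i >= 0; --i) // innermost dimension gets stride 1
    {
        strides[i] = s;
        s *= lengths[i];
    }
    // strides == {480, 160, 40, 8, 1}: within [G0, G1], [M0, M1] and [N0]
    // each inner dimension is contiguous and ordered, so Packed applies.
    // A dense-but-permuted layout (e.g. Ns before Ms in memory) would not.
    return 0;
}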
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

@@ -116,14 +117,17 @@ __global__ void
// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
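
For orientation, a scalar reference of the fused operation described by the comment above, written only to pin down the index roles (A is M x K, B0 is N x K, B1 is O x N, C is M x O, matching the attention use case Q/K/V). This is illustrative host code, not the device implementation:

#include <cmath>
#include <vector>

void reference_gemm_softmax_gemm(const float* a,  // [M, K]
                                 const float* b0, // [N, K]
                                 const float* b1, // [O, N]
                                 float* c,        // [M, O]
                                 int M, int N, int K, int O, float scale)
{
    std::vector<float> acc0(N);
    for(int m = 0; m < M; ++m)
    {
        float mx = -INFINITY, sum = 0.f;
        for(int n = 0; n < N; ++n) // Acc0 = scale * A * B0^T
        {
            float s = 0.f;
            for(int k = 0; k < K; ++k)
                s += a[m * K + k] * b0[n * K + k];
            acc0[n] = s * scale;
            mx      = std::fmax(mx, acc0[n]);
        }
        for(int n = 0; n < N; ++n) // softmax over the N (Gemm1 K) axis
        {
            acc0[n] = std::exp(acc0[n] - mx);
            sum += acc0[n];
        }
        for(int o = 0; o < O; ++o) // Acc1 = softmax(Acc0) * B1^T
        {
            float s = 0.f;
            for(int n = 0; n < N; ++n)
                s += (acc0[n] / sum) * b1[o * N + n];
            c[m * O + o] = s;
        }
    }
}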
template <typename ALayout,
          typename BLayout, // B0Layout
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO, // NumDimGemm1N
          typename ADataType,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
@@ -132,6 +136,10 @@ template <typename ALayout,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization B1Spec,
          TensorSpecialization CSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
@@ -172,283 +180,135 @@ template <typename ALayout,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          bool MaskOutUpperTriangle,
          MaskingSpecialization MaskingSpec,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
    : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                 BLayout,
                                                 B1Layout,
                                                 CPermuteNumDims_G_M_Gemm1N,
    : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                 NumDimM,
                                                 NumDimN,
                                                 NumDimK,
                                                 NumDimO,
                                                 ADataType,
                                                 BDataType,
                                                 B1DataType,
                                                 CDataType,
                                                 Acc0BiasDataType,
                                                 Acc1BiasDataType,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 AccElementwiseOperation,
                                                 B1ElementwiseOperation,
                                                 CElementwiseOperation>
                                                 CElementwiseOperation,
                                                 MaskingSpec>
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimensions must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");

#if 0
    // TODO ANT: use alias
    static constexpr index_t NumDimGemm0M = NumDimM;
    static constexpr index_t NumDimGemm0N = NumDimN;
    static constexpr index_t NumDimGemm0K = NumDimK;
    static constexpr index_t NumDimGemm1M = NumDimM;
    static constexpr index_t NumDimGemm1N = NumDimO;
    static constexpr index_t NumDimGemm1K = NumDimN;
#endif

    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto matrix_padder =
        GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
        GemmSpec,
        ASpec,
        BSpec,
        B1Spec,
        CSpec>;

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);

        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
            Number<AK1>{});
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                              const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);

        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
            Number<BK1>{});
    }

    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto
    MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
                                   const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
    {
        const auto b1_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);

        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
            Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
                                                b1_gs_gemm1ns_gemm1ks_strides_vec),
            Number<B1K1>{});
    }

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_ns_strides_vec)
    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

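The *_AK0_M_AK1 descriptors above split the merged K axis into an outer AK0 and an inner AK1 so that the innermost AK1 elements can be moved as one vector. A trivial sketch of the index arithmetic behind that unmerge, with made-up sizes:

#include <cassert>

int main()
{
    constexpr int K   = 64;
    constexpr int AK1 = 8;       // vector width along K
    constexpr int AK0 = K / AK1; // 8 outer K iterations

    // (ak0, ak1) recovers the original k coordinate:
    for(int ak0 = 0; ak0 < AK0; ++ak0)
        for(int ak1 = 0; ak1 < AK1; ++ak1)
        {
            int k = ak0 * AK1 + ak1;
            assert(k < K);
        }
    return 0;
}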
    constexpr static auto make_MaskOutPredicate()
    {
        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N

        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto c_ms_ns_lengths = to_tuple(
            c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);

        // naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
        const auto c_grid_desc_ms_ns =
            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
        const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
            c_grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_ns_strides_vec)
    {
        constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
        constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
        constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N

        assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto c_gs_ms_ns_lengths =
            to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_gs_ms_ns_strides =
            to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);

        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto c_grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
        // N2 * ...]
        const auto c_grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // this desc is only for calculating batch offset so no padding needed
        return c_grid_desc_g_mraw_nraw;
    }

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N({}, {}));
    using CGridDesc_G_M_N      = decltype(MakeCGridDescriptor_G_M_N({}, {}));

    // to track the points which need to be set to -inf on C0
    // Note: no need to reset M padding value, because they will not be stored out.
    struct C0MatrixMask
    {
        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return n >= NRaw_;
            return MaskDisabledPredicate{};
        }

        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
        {
            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
            return MaskOutUpperTrianglePredicate{};
        }

        private:
        // index_t MRaw_;
        index_t NRaw_;
    };
    }

    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB,
                                     index_t BatchStrideB1,
                                     CGridDesc_G_M_N c_grid_desc_g_m_n)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideB1_(BatchStrideB1),
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
            return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
            return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
@@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideB1_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
    };

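Why ComputeBasePtrOfStridedBatch now carries G_M_K/G_N_K/G_M_N descriptors instead of three scalar batch strides: once the G dimensions can be permuted in memory, the per-batch offset is no longer a single multiple of one stride. A host-side sketch of the effect (two G dims laid out as A[G0, M, G1, K]; all names and sizes invented for illustration):

#include <cassert>

int main()
{
    constexpr int G0 = 2, G1 = 3, M = 4, K = 8;
    // Strides of A[G0, M, G1, K] in memory:
    constexpr int sK = 1, sG1 = K, sM = G1 * K, sG0 = M * G1 * K;
    (void)sK;
    (void)sM;

    // Flattened batch index g in [0, G0*G1) -> element offset of (g, m=0, k=0).
    auto batch_offset = [&](int g) {
        int g0 = g / G1, g1 = g % G1;
        return g0 * sG0 + g1 * sG1; // not expressible as g * const_stride
    };
    assert(batch_offset(1) == sG1 && batch_offset(G1) == sG0);
    return 0;
}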
@@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        matrix_padder.PadN,
        MaskOutUpperTriangle>;
        Transform::matrix_padder.PadN,
        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;

    // Argument
    // FIXME: constness
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t Gemm1NRaw, // = ORaw
                 index_t Batch,
                 std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                 std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideB1,
                 index_t BatchStrideA,
                 index_t BatchStrideB,
                 index_t BatchStrideB1,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
        Argument(
            const ADataType* p_a_grid,
            const BDataType* p_b_grid,
            const B1DataType* p_b1_grid,
            CDataType* p_c_grid,
            const std::array<void*, NumAcc0Bias> p_acc0_biases,
            const std::array<void*, NumAcc1Bias> p_acc1_biases,
            const std::vector<index_t>& a_gs_ms_ks_lengths,
            const std::vector<index_t>& a_gs_ms_ks_strides,
            const std::vector<index_t>& b_gs_ns_ks_lengths,
            const std::vector<index_t>& b_gs_ns_ks_strides,
            const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
            const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
            const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
            const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
            const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
            const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
            const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
            AElementwiseOperation a_element_op,
            BElementwiseOperation b_element_op,
            AccElementwiseOperation acc_element_op,
            B1ElementwiseOperation b1_element_op,
            CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              b1_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                 c_gs_ms_gemm1ns_strides)},
              c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                     c_gs_ms_gemm1ns_strides)},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_g_n_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},
@@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
              c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
                                            b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
              b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
                               b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
              b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
                                b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{
                  BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
              c0_matrix_mask_{NRaw},
              raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
              c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
              c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ns_lengths;
            ignore = acc0_biases_gs_ms_ns_strides;
            ignore = acc1_biases_gs_ms_gemm1ns_lengths;
            ignore = acc1_biases_gs_ms_gemm1ns_strides;

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
@@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }
        }

        // private:
        void Print() const
        {
            std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I1) << ", "
                      << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
            // a_grid_desc_g_m_k_.Print();
            std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b_grid_desc_g_n_k_.Print();
            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
            // b1_grid_desc_g_n_k_.Print();
            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // c_grid_desc_g_m_n_.Print();
        }

        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // For robust IsSupportedArgument() check
        std::vector<index_t> raw_lengths_m_n_k_o_;
        index_t c_extent_lowest_;
        index_t c_stride_lowest_;
        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
        std::vector<index_t> a_mz_kz_strides_;
        std::vector<index_t> b_nz_kz_strides_;
        std::vector<index_t> b1_nz_kz_strides_;
        std::vector<index_t> c_mz_gemm1nz_strides_;

        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
    };

    // Invoker
@@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            if(!DeviceOp::IsSupportedArgument(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
                throw std::runtime_error("wrong! unsupported argument");
            }

            const index_t grid_size =

@@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

    static bool IsSupportedArgument(const Argument& arg)
    {
#if 0
        arg.Print();
#endif

        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        // TODO ANT: Check if tensor specialization & strides mismatch

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g       = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m       = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n  = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t a_m       = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
        {
            return false;
@@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        // Note: we need raw lengths since threadwise copy cannot handle vector load when part of
        // vector is out of bounds
        const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
        const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
        const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
        const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
        // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
        const auto MzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
        const auto NzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
        const auto KzRaw      = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
        const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];

        // Check scalar per vector requirement
        const auto a_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
        const auto b_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
        const auto b1_extent_lowest =
            is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
        const auto c_extent_lowest = arg.c_extent_lowest_;
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
        const auto b_extent_lowest  = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
        const auto c_extent_lowest  = Gemm1NzRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
@@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            return false;
        }

        // Check vector store requirement; assumes last dimension in N to be contiguous
        if(arg.c_stride_lowest_ != 1)
        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
        const auto b_stride_lowest =
            BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
        const auto c_stride_lowest =
            arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous

        if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            return false;
        }

@@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t Gemm1NRaw,
                             index_t Batch,
                             std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                             std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideB1,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             index_t BatchStrideB1,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
    static auto MakeArgument(
        const ADataType* p_a,
        const BDataType* p_b,
        const B1DataType* p_b1,
        CDataType* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_b1,
                        p_c,
                        MRaw,
                        NRaw,
                        KRaw,
                        Gemm1NRaw,
                        Batch,
                        c_gs_ms_gemm1ns_lengths,
                        c_gs_ms_gemm1ns_strides,
                        StrideA,
                        StrideB,
                        StrideB1,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideB1,
                        p_acc0_biases,
                        p_acc1_biases,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                        b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                        c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                        acc0_biases_gs_ms_ns_lengths,
                        acc0_biases_gs_ms_ns_strides,
                        acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                        acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        a_element_op,
                        b_element_op,
                        acc_element_op,
@@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

    // polymorphic
    // FIXME: constness
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        const void* p_b1,
                        void* p_c,
                        index_t MRaw,
                        index_t NRaw,
                        index_t KRaw,
                        index_t Gemm1NRaw,
                        index_t Batch,
                        std::vector<index_t> c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                        std::vector<index_t> c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                        index_t StrideA,
                        index_t StrideB,
                        index_t StrideB1,
                        index_t BatchStrideA,
                        index_t BatchStrideB,
                        index_t BatchStrideB1,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) override
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b_gs_ns_ks_lengths,
        const std::vector<index_t>& b_gs_ns_ks_strides,
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
        AElementwiseOperation a_element_op,
        BElementwiseOperation b_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          Gemm1NRaw,
                                          Batch,
                                          c_gs_ms_gemm1ns_lengths,
                                          c_gs_ms_gemm1ns_strides,
                                          StrideA,
                                          StrideB,
                                          StrideB1,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideB1,
                                          p_acc0_biases, // cast in struct Argument
                                          p_acc1_biases, // cast in struct Argument
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
                                          b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                                          c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                                          c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                                          acc0_biases_gs_ms_ns_lengths,
                                          acc0_biases_gs_ms_ns_strides,
                                          acc1_biases_gs_ms_gemm1ns_lengths,
                                          acc1_biases_gs_ms_gemm1ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
@@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ">";
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">";
        // clang-format on

        return str.str();

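To see how the new length/stride interface expresses an input permute without a separate transpose pass, here is a sketch of the vectors one would hand to MakeArgument()/MakeArgumentPointer() for a query tensor stored as [B, M, H, K] in memory while the op expects the logical order a_gs_ms_ks = [G0=B, G1=H, M, K]. All sizes are invented for illustration; the example files added by this commit pass per-tensor lengths/strides in the same spirit:

#include <vector>

int main()
{
    const int B = 16, H = 8, M = 256, K = 64;

    // Logical lengths, in gs_ms_ks order:
    std::vector<int> a_gs_ms_ks_lengths{B, H, M, K};

    // Strides of the physical [B, M, H, K] layout, reported in the same
    // logical order: B -> M*H*K, H -> K, M -> H*K, K -> 1.
    std::vector<int> a_gs_ms_ks_strides{M * H * K, K, H * K, 1};

    // Handing these to the device op lets it read the permuted tensor
    // in place; only the descriptors change, not the data.
    (void)a_gs_ms_ks_lengths;
    (void)a_gs_ms_ks_strides;
    return 0;
}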
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"

@@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                          BElementwiseOperation,
                                          AccElementwiseOperation,
                                          B1ElementwiseOperation,
                                          CElementwiseOperation>
                                          CElementwiseOperation,
                                          MaskOutUpperTriangle>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;

@@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    // to track the points which need to be set to -inf on C0
    // Note: no need to reset M padding value, because they will not be stored out.
    struct C0MatrixMask
    {
        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }

        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
        {
            return n >= NRaw_;
        }

        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
        {
            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
        }

        private:
        // index_t MRaw_;
        index_t NRaw_;
    };

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,

@@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    using C0MatrixMask = conditional_t<MaskOutUpperTriangle,
                                       C0MatrixMask_impl<MaskOutUpperTrianglePredicate>,
                                       C0MatrixMask_impl<MaskDisabledPredicate>>;

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

enum struct MaskingSpecialization
{
    MaskDisabled,
    MaskOutUpperTriangle
};

inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
    switch(s)
    {
    case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
    case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
    default: return "Unrecognized specialization!";
    }
}

struct MaskDisabledPredicate
{
    __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
    {
        return false;
    };

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
    {
        return false;
    }
};

struct MaskOutUpperTrianglePredicate
{
    __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
    {
        return operator()(m + m_tile - 1, n);
    }
};

// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
    C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}

    __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
    {
        return n >= NRaw_;
    }

    __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
    {
        return predicate_(m, n) || IsNOutOfBound(n);
    }

    __host__ __device__ constexpr bool
    IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
    {
        return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
    }

    private:
    // index_t MRaw_;
    index_t NRaw_;
    MaskOutPredicate predicate_;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
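
A standalone re-statement of the masking composition above, useful for checking the behavior on the host (the real types live in ck::tensor_operation::device; this sketch mirrors them with plain C++). With NRaw = 3 on a 4x4 tile it shows the causal predicate (n > m) combined with the N-padding check (n >= NRaw):

#include <cstdio>

struct MaskOutUpperTrianglePredicate
{
    constexpr bool operator()(int m, int n) const { return n > m; }
};

template <typename Pred>
struct C0MatrixMask
{
    int NRaw;
    Pred pred{};
    bool IsMaskedElement(int m, int n) const { return pred(m, n) || n >= NRaw; }
};

int main()
{
    C0MatrixMask<MaskOutUpperTrianglePredicate> mask{3};
    for(int m = 0; m < 4; ++m)
    {
        for(int n = 0; n < 4; ++n)
            std::printf("%c ", mask.IsMaskedElement(m, n) ? 'x' : '.');
        std::printf("\n");
    }
    // Prints:
    // . x x x
    // . . x x
    // . . . x   <- n == 3 masked only by the n >= NRaw padding check
    // . . . x
    return 0;
}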
@@ -336,36 +336,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
};
|
||||
|
||||
template <bool Pred>
|
||||
struct ElementOpPredicatedResetNaNToMinusInf;
|
||||
|
||||
template <>
|
||||
struct ElementOpPredicatedResetNaNToMinusInf<true>
|
||||
{
|
||||
template <typename ElementOp, typename OutT, typename InT>
|
||||
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
|
||||
{
|
||||
if(ck::math::isnan(x))
|
||||
{
|
||||
y = -ck::NumericLimits<float>::Infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
op(y, x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementOpPredicatedResetNaNToMinusInf<false>
|
||||
{
|
||||
template <typename ElementOp, typename OutT, typename InT>
|
||||
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
|
||||
{
|
||||
op(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -406,11 +376,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return;
|
||||
}
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
const index_t gemm1_n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
@@ -627,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
true, // DstResetCoord
|
||||
NumGemmKPrefetchStage>(
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
|
||||
b1_element_op,
|
||||
b1_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0),
|
||||
@@ -745,29 +715,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
running_max = NumericLimits<FloatGemmAcc>::Lowest();
|
||||
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
|
||||
|
||||
// decoder lower triangular mask
|
||||
const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_n_cluster_id = thread_cluster_idx[I1];
|
||||
const index_t MPerRepeat = MPerBlock / MXdlPerWave;
|
||||
const index_t NPerRepeat = NPerBlock / NXdlPerWave;
|
||||
const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
|
||||
|
||||
// gemm1 K loop
|
||||
index_t gemm1_k_block_outer_index = 0;
|
||||
do
|
||||
{
|
||||
if constexpr(MaskOutUpperTriangle)
|
||||
auto n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
|
||||
if(c0_matrix_mask.IsTileSkippable(
|
||||
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
|
||||
{
|
||||
auto gemm0_n_block_idx =
|
||||
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
|
||||
if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
|
||||
c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
|
||||
gemm0_n_block_idx))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// gemm0
|
||||
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
|
||||
@@ -789,60 +746,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
// do MNK padding or upper triangular masking
|
||||
if constexpr(MaskOutUpperTriangle || PadN)
|
||||
{
|
||||
const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
|
||||
// 8d thread_desc in thread scope
|
||||
constexpr auto c_thread_lengths =
|
||||
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
|
||||
|
||||
static_for<0, m0, 1>{}([&](auto m0_i) {
|
||||
const index_t m_global = mstart + m0_i * MPerRepeat;
|
||||
const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
|
||||
static_for<0, n0, 1>{}([&](auto n0_i) {
|
||||
// constexpr auto nrepeat_i = n0_i * NPerRepeat;
|
||||
// const index_t nstartxdl = nstart + nrepeat_i;
|
||||
const index_t nstartxdl = nstart + n0_i * NPerRepeat;
|
||||
const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
|
||||
static_for<0, n2, 1>{}([&](auto n2_i) {
|
||||
const index_t nstartgroup =
|
||||
nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
|
||||
const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
|
||||
static_for<0, n4, 1>{}([&](auto n4_i) {
|
||||
const index_t n_global = nstartgroup + n4_i;
|
||||
const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
|
||||
if constexpr(MaskOutUpperTriangle)
|
||||
{
|
||||
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
|
||||
{
|
||||
acc_thread_buf(acc_offset) =
|
||||
-ck::NumericLimits<float>::Infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_element_op(acc_thread_buf(acc_offset),
|
||||
acc_thread_buf[acc_offset]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// ignore m_global;
|
||||
if(c0_matrix_mask.IsNOutOfBound(n_global))
|
||||
{
|
||||
acc_thread_buf(acc_offset) =
|
||||
-ck::NumericLimits<float>::Infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_element_op(acc_thread_buf(acc_offset),
|
||||
acc_thread_buf[acc_offset]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
// 8d block_desc in block scope
constexpr auto c_block_lengths =
    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();

constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];

// works like a multi-dimensional static_for (static_ford), but provides both
// the linear index and the n-d index
using Acc0TileIterator = SpaceFillingCurve<
    decltype(c_thread_lengths),
    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
    false>; // SnakeCurved

auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
    Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});

constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
    make_tuple(Sequence<0>{}, Sequence<1>{}),
    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));

static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
    auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
    auto m_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
    auto n_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
    auto m_global = m_local + m_block_data_idx_on_grid;
    auto n_global = n_local + n_block_data_idx_on_grid;
    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
    {
        acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
    }
    else
    {
        acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
    }
});
}
else
{
    static_for<0, acc_thread_buf.Size(), 1>{}(
        [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}

block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
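The Acc0TileIterator above walks the accumulator tile by linear access index while also exposing the corresponding n-d coordinate. A minimal sketch of that linear-to-n-d decode for a non-snaked curve (illustrative C++, not CK's SpaceFillingCurve):

#include <array>

// Hedged sketch: decode a linear index into an n-d index, row-major order,
// i.e. the last dimension varies fastest (scalars-per-access all 1, no snaking).
template <int D>
std::array<int, D> linear_to_nd(int linear, const std::array<int, D>& lengths)
{
    std::array<int, D> idx{};
    for(int d = D - 1; d >= 0; --d)
    {
        idx[d] = linear % lengths[d];
        linear /= lengths[d];
    }
    return idx;
}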
@@ -593,7 +593,8 @@ struct XdlopsGemm
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }

@@ -822,6 +823,16 @@ struct XdlopsGemm
        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
    {
        const auto blk_idx = GetBlkIdx();

        const auto blk_id = blk_idx[I0];
        const auto blk_td = blk_idx[I1];

        return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
    }

    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};

    static constexpr auto mfma_instr = mfma.selected_mfma;

@@ -0,0 +1,288 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"

namespace ck {
namespace tensor_operation {

// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
                                   const std::vector<index_t>& gs_ms_ns_strides_vec)
{
    if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
         gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    {
        throw std::runtime_error("wrong! dimension must match input lengths");
    }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds =
        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds =
        typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        // Note: this does not require padding since it only provides the G offset
        // calculation. Technically only a descriptor for G is needed; G_M_N is
        // returned for backward compatibility.
        const auto grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths = to_tuple(
            gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}
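The returned pair is (G_MRaw_NRaw descriptor for per-batch offsetting, MRaw_NRaw descriptor for the GEMM view). A hedged usage sketch, assuming an attention-style tensor with NumDimG = 2 (batch, head) and illustrative shape/stride values:

// Hedged sketch: C[G0, G1, M0, N0] = C[64, 16, 1024, 64] stored in the
// permuted memory order [G0, M0, G1, N0].
std::vector<ck::index_t> lengths{64, 16, 1024, 64};
std::vector<ck::index_t> strides{1024 * 16 * 64, 64, 16 * 64, 1};

const auto descs = ck::tensor_operation::MakeGridDescriptorPair<
    2, 1, 1, ck::tensor_operation::device::TensorSpecialization::Default>(lengths, strides);
// descs.first  : G_MRaw_NRaw (G merged from G0 * G1), used for batch offsets
// descs.second : MRaw_NRaw view consumed by the GEMM pipeline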

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                          const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }
    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B (alias of B0)
    //
    static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                           const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }
    static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B1
    //
    static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                           const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }
    static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // C
    //
    static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }
    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
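A hedged sketch of instantiating the transform above for the attention case A[G0, G1, M, K] x B0[G0, G1, N, K] -> softmax -> x B1[G0, G1, O, N]; the per-block tile sizes are illustrative only:

using Transform = ck::tensor_operation::TransformBatchedContractionContractionToBatchedGemmGemm<
    ck::Sequence<2, 1, 1, 1, 1>,     // NumDimG, NumDimM, NumDimN, NumDimK, NumDimO
    ck::Sequence<128, 128, 32, 128>, // MPerBlock, NPerBlock, KPerBlock, OPerBlock
    ck::tensor_operation::device::GemmSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default>;

// e.g. the padded M_K view of A:
// auto a_desc = Transform::MakeAGridDescriptor_M_K(a_lengths, a_strides);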
@@ -1,100 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using CPermuteNumDims_G_M_O =
    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                    Col,
                                                                    Row,
                                                                    CPermuteNumDims_G_M_O,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    PassThrough,
                                                                    PassThrough,
                                                                    Scale,
                                                                    PassThrough,
                                                                    PassThrough>>>& instances);

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CPermuteNumDims_G_M_Gemm1N,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                                      B0Layout,
                                                                      B1Layout,
                                                                      CPermuteNumDims_G_M_Gemm1N,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      Scale,
                                                                      PassThrough,
                                                                      PassThrough>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                         B0Layout,
                                                         B1Layout,
                                                         CPermuteNumDims_G_M_Gemm1N,
                                                         ADataType,
                                                         B0DataType,
                                                         B1DataType,
                                                         CDataType,
                                                         PassThrough,
                                                         PassThrough,
                                                         Scale,
                                                         PassThrough,
                                                         PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> &&
                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
            {
                add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
        }
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
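The S<2, 1, 1> alias above encodes how many collapsed dimensions the permuted C tensor contributes to G, M and O respectively. A hedged reading of what that implies for the "gmo" output of the instances declared here (shape values illustrative):

// Hedged sketch: CPermuteNumDims_G_M_O = S<2, 1, 1> means C is indexed as
// C[G0, G1, M, O]. The gmk_gnk_gno_gmo naming suggests the output is written
// in memory order [G0, M, G1, O] (batch, seqlen, head, head_dim), which a
// single RowMajor CLayout cannot express; hence the stride-based interface.
// For batch = 64, heads = 16, seqlen = 1024, head_dim = 64 one would pass
//   lengths (G0, G1, M, O) = (64, 16, 1024, 64)
//   strides (G0, G1, M, O) = (1024 * 16 * 64, 64, 16 * 64, 1)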
@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             Scale,
                                                             PassThrough,
                                                             PassThrough,
                                                             PassThrough>>>& instances);
                                                             false>>>& instances);

void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
                                                             Col,
                                                             Row,
                                                             Row,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             F16,
                                                             PassThrough,
                                                             PassThrough,
                                                             Scale,
                                                             PassThrough,
                                                             PassThrough,
                                                             true>>>& instances);

template <typename ALayout,
          typename B0Layout,
@@ -39,7 +56,8 @@ template <typename ALayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType>
          typename CDataType,
          bool MaskOutUpperTriangle>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                               B0Layout,
@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory<
                                                               CDataType,
                                                               PassThrough,
                                                               PassThrough,
                                                               Scale,
                                                               PassThrough,
                                                               PassThrough,
                                                               PassThrough>>
                                                               MaskOutUpperTriangle>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                  B0Layout,
@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory<
                                                  CDataType,
                                                  PassThrough,
                                                  PassThrough,
                                                  Scale,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>;
                                                  MaskOutUpperTriangle>;

    static auto GetInstances()
    {
@@ -79,8 +99,16 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                     is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
        {
            add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                op_ptrs);
            if constexpr(MaskOutUpperTriangle)
            {
                add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else
            {
                add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
        }
    }
    return op_ptrs;
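A hedged usage sketch of the factory above; the operation alias mirrors the declarations in this header and `true` selects the masked (upper-triangle) instance set:

using MaskedOp = ck::tensor_operation::device::
    DeviceBatchedGemmSoftmaxGemm<Row, Col, Row, Row, F16, F16, F16, F16,
                                 PassThrough, PassThrough, Scale, PassThrough,
                                 PassThrough, /* MaskOutUpperTriangle = */ true>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<MaskedOp>::GetInstances();
// each candidate is then typically filtered with IsSupportedArgument(...)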

@@ -0,0 +1,129 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2,
                                            1,
                                            1,
                                            1,
                                            1,
                                            F16,
                                            F16,
                                            F16,
                                            F16,
                                            ck::Tuple<>,
                                            ck::Tuple<>,
                                            PassThrough,
                                            PassThrough,
                                            Scale,
                                            PassThrough,
                                            PassThrough,
                                            MaskingSpecialization::MaskOutUpperTriangle>>>&
        instances);

void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<
        std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                            1,
                                                            1,
                                                            1,
                                                            1,
                                                            F16,
                                                            F16,
                                                            F16,
                                                            F16,
                                                            ck::Tuple<>,
                                                            ck::Tuple<>,
                                                            PassThrough,
                                                            PassThrough,
                                                            Scale,
                                                            PassThrough,
                                                            PassThrough,
                                                            MaskingSpecialization::MaskDisabled>>>&
        instances);

template <typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                      1,
                                                                      1,
                                                                      1,
                                                                      1,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      ck::Tuple<>,
                                                                      ck::Tuple<>,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      Scale,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      MaskingSpec>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                         1,
                                                         1,
                                                         1,
                                                         1,
                                                         ADataType,
                                                         B0DataType,
                                                         B1DataType,
                                                         CDataType,
                                                         ck::Tuple<>,
                                                         ck::Tuple<>,
                                                         PassThrough,
                                                         PassThrough,
                                                         Scale,
                                                         PassThrough,
                                                         PassThrough,
                                                         MaskingSpec>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
            {
                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
            else if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
            {
                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
                    op_ptrs);
            }
        }
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
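A hedged usage sketch of the permute factory above; the MaskingSpecialization value selects the masked or unmasked instance set at compile time (the enum's namespace is assumed to match this header):

using Op = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
    2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
    PassThrough, PassThrough, Scale, PassThrough, PassThrough,
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled>;

auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<Op>::GetInstances();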

@@ -6,6 +6,7 @@ function(add_instance_library INSTANCE_NAME)
    clang_tidy_check(${INSTANCE_NAME})
endfunction(add_instance_library INSTANCE_NAME)


file(GLOB dir_list LIST_DIRECTORIES true *)
set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})

@@ -36,10 +36,10 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
@@ -36,10 +36,10 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_inst
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can trigger compiler crash in mainline #9110 but not in #10738
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
@@ -1,4 +0,0 @@
add_instance_library(device_batched_gemm_masking_scale_softmax_gemm_permute_instance
    device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)

@@ -1,85 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using CPermuteNumDims_G_M_O =
    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
        // 2 of them are commented out because they trigger the clang-13 issue.
//##############################################| ALayout| B0Layout| B1Layout| CPermuteNumDims_G_M_O| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
//##############################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//##############################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
//DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
//DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>
// clang-format on
>;

void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                    Col,
                                                                    Row,
                                                                    CPermuteNumDims_G_M_O,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    PassThrough,
                                                                    PassThrough,
                                                                    Scale,
                                                                    PassThrough,
                                                                    PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
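add_device_operation_instances above appends one type-erased instance per element of the tuple. A minimal sketch of that mechanism (not the actual CK implementation):

#include <memory>
#include <tuple>
#include <vector>

// Hedged sketch: append a default-constructed, type-erased device op for each
// kernel type in the tuple, roughly what add_device_operation_instances does.
template <typename BaseOp, typename... Ops>
void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances, std::tuple<Ops...>)
{
    (instances.push_back(std::make_unique<Ops>()), ...); // C++17 fold expression
}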
@@ -24,11 +24,13 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <bool Masking>
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
@@ -36,24 +38,25 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>,
|
||||
// Padded fallback kernel
|
||||
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
|
||||
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>
|
||||
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
|
||||
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
|
||||
// clang-format on
|
||||
>;
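
// Hedged decoding (added commentary) of the first instance in the tuple above,
// following the column headers used later in this file: BlockSize 256; Gemm0
// tile MPerBlock/NPerBlock/KPerBlock = 128/256/32; Gemm1 NPerBlock/KPerBlock =
// 128/32; AK1 = BK1 = 8, B1K1 = 2; 32x32 XDL instructions. The trailing
// `Masking` flag selects whether the upper triangle of the attention scores is
// masked out (causal attention) or left untouched.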

template <bool Masking>
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances =
std::tuple<
// clang-format off
@@ -61,12 +64,14 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
#if CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE == 0
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
#endif
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
// clang-format on
>;

@@ -81,16 +86,45 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
F16,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
PassThrough>>>& instances)
false>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
false>{});
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances<
false>{});
}

void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
true>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
true>{});
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances<
true>{});
}

} // namespace instance

@@ -0,0 +1,4 @@
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)

@@ -0,0 +1,133 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;

void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
}

void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
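
// Hedged usage sketch (added; names assumed from the factory pattern this
// library uses elsewhere, not verified against this exact header):
//
//   using DeviceOp = ck::tensor_operation::device::
//       DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, F16, F16, F16, F16,
//           ck::Tuple<>, ck::Tuple<>, PassThrough, PassThrough, Scale,
//           PassThrough, PassThrough, MaskingSpecialization::MaskDisabled>;
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   // each op_ptr is type-erased; check support before invoking it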
@@ -29,7 +29,8 @@ template <typename ADataType,
typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout>
typename CLayout,
bool MaskOutUpperTriangle>
bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int init_method,
bool do_log,
@@ -46,16 +47,18 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int BatchStrideA = -1,
int BatchStrideB0 = -1,
int BatchStrideB1 = -1,
int BatchStrideC = -1)
int BatchStrideC = -1,
float alpha = 1.f)

{

using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using Scale = tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using Acc0ElementOp = Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AccDataType = float;
@@ -67,7 +70,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
Acc0ElementOp>;

// Ref Softmax: fp32 in, various type out
using ReferenceSoftmaxInstance =
@@ -185,7 +188,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,

auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};

@@ -201,7 +204,8 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp>;
CElementOp,
MaskOutUpperTriangle>;

// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -214,10 +218,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha});

ref_gemm0_invoker.Run(ref_gemm0_argument);

// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(MaskOutUpperTriangle && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
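// Note (added commentary): writing -inf into the masked positions before the
// softmax makes exp(x - max) evaluate to exactly zero there, so masked columns
// drop out of the row sum; e.g. softmax({1, -inf, 2}) gives the middle entry a
// probability of exactly 0.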

auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

@@ -7,10 +7,10 @@

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -22,36 +22,32 @@
namespace ck {
namespace profiler {

template <typename ADataType,
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_O>
bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int G0,
int G1,
int StrideA = -1,
int StrideB0 = -1,
int StrideB1 = -1,
int BatchStrideA = -1,
int BatchStrideB0 = -1,
int BatchStrideB1 = -1,
float alpha = 1.f)
typename Acc0BiasesDataType,
typename Acc1BiasesDataType,
tensor_operation::device::MaskingSpecialization MaskingSpec>
bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int G0,
int G1,
float alpha = 1.f)

{

using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using Scale = tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
@@ -60,6 +56,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AccDataType = float;
using tensor_operation::device::MaskingSpecialization;

// Ref Gemm0: various type in, fp32 out
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<ADataType,
@@ -85,67 +82,33 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

bool pass = true;

// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
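// Note (added commentary): lengths are listed in [G0, G1, M, K]-style order,
// while the strides above encode a [G0, M, G1, K]-style memory order, e.g. for
// A: offset(g0, g1, m, k) = g0*(M*G1*K) + m*(G1*K) + g1*K + k. This interleaved
// layout is the input/output permutation the fused kernel absorbs instead of
// requiring a separate transpose pass.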

const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;

StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;

const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;

BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;

const int BatchCount = G0 * G1;

auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_host_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});

std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

std::srand(1); // work around test flakiness
@@ -157,38 +120,38 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// or not. May want to try exact same approach as the GPU kernel in the host reference
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
// a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
// b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
// a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
// b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}

DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());

auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
@@ -196,20 +159,23 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};

using DeviceOp =
tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp>;
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;

// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -219,6 +185,26 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1

// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
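// Note (added commentary): the three loops above fold the [G0, G1] batch
// dimensions into a single index, g = g0 * G1 + g1, for the host reference;
// B0 and B1 are additionally transposed (idx[3]/idx[2] swapped) to recover
// the plain [g, k, n] and [g, n, o] views the reference GEMMs expect.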

auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
@@ -228,7 +214,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(idx[1] < idx[2])
if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});

@@ -265,23 +251,24 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
@@ -319,18 +306,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

if(do_verification)
{
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",")
LogRangeAsType<float>(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",")
LogRangeAsType<float>(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
LogRangeAsType<float>(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
@@ -41,7 +41,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)

@@ -1,5 +0,0 @@
add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)

add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)
@@ -9,9 +9,13 @@ class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
{
};

using Masked = std::true_type;
using NoMask = std::false_type;

// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, NoMask>,
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, Masked>
>;
// clang-format on
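
// Hedged sketch (added): std::true_type / std::false_type carry the mask flag
// as a compile-time constant; the fixture reads it back through
// std::tuple_element_t<8, Tuple>::value when dispatching the profiler call.
static_assert(Masked::value && !NoMask::value, "mask flags are compile-time constants");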

@@ -120,7 +124,6 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)

using ck::tensor_operation::device::GemmSpecialization;

// TODO: enable KPadding tests when it is implemented
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
@@ -152,12 +155,12 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
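
// Hedged sketch (added, illustrative only): the alignment rule the comments
// above describe; the real check lives in each instance's support query.
constexpr bool is_vector_aligned(int size_raw, int scalar_per_vector)
{
    return size_raw % scalar_per_vector == 0; // e.g. 129 % 8 != 0 -> unsupported
}
static_assert(!is_vector_aligned(129, 8), "odd K cannot use 8-wide vector loads");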

@@ -169,6 +172,5 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
{1020, 1020, 64, 128, 24},
{576, 576, 64, 64, 24},
};
this->bench_ = true;
this->Run();
}

@@ -20,14 +20,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
using MaskingType = std::tuple_element_t<8, Tuple>;

std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
@@ -54,7 +55,8 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
ALayout,
B0Layout,
B1Layout,
CLayout>(
CLayout,
MaskingType::value>(
verify_, 1, false, bench_, M, N, K, O, BatchCount);

EXPECT_TRUE(pass);

test/batched_gemm_softmax_gemm_permute/CMakeLists.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
|
||||
add_custom_target(test_batched_gemm_softmax_gemm_permute)
|
||||
|
||||
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
|
||||
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
|
||||
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
|
||||
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
|
||||
@@ -10,13 +10,18 @@ class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
|
||||
{
|
||||
};
|
||||
|
||||
using I1_t = ck::Number<1>;
|
||||
using I2_t = ck::Number<2>;
|
||||
|
||||
using MaskDisabled_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
|
||||
using MaskOutUpperTriangle_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
|
||||
|
||||
// clang-format off
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using CPermuteNumDims_G_M_O =
|
||||
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -91,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Bench_FP16_IrregularK)
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
|
||||
{256, 64, 160, 64, 1, 16},
|
||||
@@ -125,7 +130,6 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP1
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
// TODO: enable KPadding tests when it is implemented
|
||||
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
|
||||
{
|
||||
int P = 120; // requires padding
|
||||
@@ -133,22 +137,22 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
|
||||
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -156,13 +160,13 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
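The failing K/O cases in this test encode the vector-width constraint quoted in the comments above, not a tile-padding rule. A minimal sketch of that check, assuming 8-wide vector loads along K and O to match this instance (constant and function names here are illustrative, not the kernel's members):

// Sketch of the alignment rule: padding specializations can absorb a
// tile-size mismatch (e.g. K = 120 with 32-wide K tiles), but not a raw
// extent that breaks vectorized loads: 129 % 8 != 0 and 130 % 8 != 0,
// hence the EXPECT_FALSE cases above even under MNKOPadding.
constexpr int ABSrcScalarPerVector = 8; // assumed: A/B0 loads vectorize along K
constexpr int B1SrcScalarPerVector = 8; // assumed: B1 loads vectorize along O

bool RawSizesVectorizable(int K_raw, int O_raw)
{
    return (K_raw % ABSrcScalarPerVector == 0) && (O_raw % B1SrcScalarPerVector == 0);
}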

@@ -174,6 +178,5 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4, 6},
};
this->bench_ = true;
this->Run();
}
@@ -4,10 +4,14 @@
#include <iostream>

#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp"

using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MaskingSpecialization;
using ck::tensor_operation::device::TensorSpecialization;

template <ck::index_t N>
using I = ck::Number<N>;
@@ -20,14 +24,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;
using NumDimGType = std::tuple_element_t<0, Tuple>;
using NumDimMType = std::tuple_element_t<1, Tuple>;
using NumDimNType = std::tuple_element_t<2, Tuple>;
using NumDimKType = std::tuple_element_t<3, Tuple>;
using NumDimOType = std::tuple_element_t<4, Tuple>;
using ADataType = std::tuple_element_t<5, Tuple>;
using B0DataType = std::tuple_element_t<6, Tuple>;
using B1DataType = std::tuple_element_t<7, Tuple>;
using CDataType = std::tuple_element_t<8, Tuple>;
using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
using MaskingType = std::tuple_element_t<11, Tuple>;

std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 6, 4},
@@ -42,15 +50,20 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test

void RunSingle(int M, int N, int K, int O, int G0, int G1)
{
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
ADataType,
B0DataType,
B1DataType,
CDataType,
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
bool pass =
ck::profiler::profile_batched_gemm_softmax_gemm_permute_impl<NumDimGType::value,
NumDimMType::value,
NumDimNType::value,
NumDimKType::value,
NumDimOType::value,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<>,
ck::Tuple<>,
MaskingType::value>(
verify_, 1, false, bench_, M, N, K, O, G0, G1);

EXPECT_TRUE(pass);
}
@@ -72,19 +85,13 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
};
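For orientation, one plausible row of the typed-test type list under the new 12-element tuple layout unpacked by the fixture above. The concrete list used by the suite lives outside this hunk, so the names below (MaskUpper, ExampleKernelTuple) and the F16 shorthand are assumptions for illustration only:

#include <tuple>
#include <type_traits>

// Assumed wrapper so that MaskingType::value yields the masking enum,
// matching how the fixture reads the tuple; not taken from the suite.
using MaskUpper = std::integral_constant<MaskingSpecialization,
                                         MaskingSpecialization::MaskOutUpperTriangle>;

// NumDimG/M/N/K/O, A/B0/B1/C data types, Acc0/Acc1 bias types, masking type.
using ExampleKernelTuple = std::tuple<I<2>, I<1>, I<1>, I<1>, I<1>,
                                      ck::half_t, ck::half_t, ck::half_t, ck::half_t,
                                      ck::Tuple<>, ck::Tuple<>,
                                      MaskUpper>;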

template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;

using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using CPermuteNumDims_G_M_O =
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
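// For concreteness (editorial comment, not in the original diff): with
// CPermuteNumDims_G_M_O = S<2, 1, 1>, the output C is described by 2 batch
// dims, 1 M dim and 1 O dim, i.e. a 4-D tensor such as [G0, G1, M, O] or the
// permuted [G0, M, G1, O]. The strides passed at argument-construction time
// select the actual memory order, which is what makes a fixed "CLayout = Row"
// unnecessary.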

using ADataType = F16;
using B0DataType = F16;
@@ -103,14 +110,17 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128

using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O,
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<>,
ck::Tuple<>,
AccDataType,
CShuffleDataType,
AElementOp,
@@ -119,6 +129,10 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecialization::Default, // ATensorSpec
TensorSpecialization::Default, // B0TensorSpec
TensorSpecialization::Default, // B1TensorSpec
TensorSpecialization::Default, // CTensorSpec
1,
256,
128, // MPerBlock
@@ -159,29 +173,48 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
true>; // Masking
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
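// Editorial note: the former boolean masking flag is replaced here by the
// MaskingSpecialization enum class, so the intent is explicit at the
// instantiation site. MaskOutUpperTriangle gives the causal-attention mask;
// a non-masked instance would presumably pass
// MaskingSpecialization::MaskDisabled instead.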

bool IsSupported(int M, int N, int K, int O)
{
const int G0 = 1, G1 = 1;

// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
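// Worked example (editorial comment, not in the original diff): each
// lengths/strides pair expresses a permuted memory order directly. For A
// stored as [G0, M, G1, K] but indexed as (g0, g1, m, k), the element offset
// is
//   g0 * (M * G1 * K) + g1 * K + m * (G1 * K) + k * 1,
// which is exactly a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}. B1 plays
// the same trick with O carrying stride 1 and N carrying stride G1 * O, so
// the fused kernel consumes the permuted tensors without any separate
// transpose pass.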

auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
M,
N,
K,
O,
0, // BatchCount
{0, 0, M, O}, // gs ms ns lengths
{0, O, 0, 1}, // gs ms ns strides
0, // StrideA
0, // StrideB0
0, // StrideB1
0, // BatchStrideA
0, // BatchStrideB0
0, // BatchStrideB1
{}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
{}, // acc0_biases_gs_ms_ns_lengths
{}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
Scale{1.f}, // acc0_element_op
@@ -12,28 +12,91 @@

using namespace ck;

void traverse_using_space_filling_curve();
void traverse_using_space_filling_curve_linear();
void traverse_using_space_filling_curve_snakecurved();

int main(int argc, char** argv)
{
(void)argc;
(void)argv;

traverse_using_space_filling_curve();
traverse_using_space_filling_curve_linear();
traverse_using_space_filling_curve_snakecurved();

return 0;
}

void traverse_using_space_filling_curve()
void traverse_using_space_filling_curve_linear()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};

using TensorLengths = Sequence<16, 10, 9>;
using DimAccessOrder = Sequence<2, 0, 1>;
using ScalarsPerAccess = Sequence<4, 2, 3>;
using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
using TensorLengths = Sequence<3, 2, 2>;
using DimAccessOrder = Sequence<2, 0, 1>;
using ScalarsPerAccess = Sequence<1, 1, 1>;
using SpaceFillingCurve =
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;

constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
make_tuple(0, 1, 0),
make_tuple(1, 0, 0),
make_tuple(1, 1, 0),
make_tuple(2, 0, 0),
make_tuple(2, 1, 0),
make_tuple(0, 0, 1),
make_tuple(0, 1, 1),
make_tuple(1, 0, 1),
make_tuple(1, 1, 1),
make_tuple(2, 0, 1),
make_tuple(2, 1, 1));

constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();

static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
math::multiplies{},
Number<1>{}));
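// Editorial note: for this instance the assertion is easy to check by hand:
// num_access = (3/1) * (2/1) * (2/1) = 12, matching the 12 index tuples in
// `expected` above.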

static_for<1, num_access, 1>{}([&](auto i) {
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

static_assert(idx_curr[I0] == expected[i][I0]);
static_assert(idx_curr[I1] == expected[i][I1]);
static_assert(idx_curr[I2] == expected[i][I2]);

constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
constexpr auto expected_step = expected[i - I1] - expected[i];
static_assert(backward_step[I0] == expected_step[I0]);
static_assert(backward_step[I1] == expected_step[I1]);
static_assert(backward_step[I2] == expected_step[I2]);
});
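// Editorial note, one concrete instance of the check above: at i = 2 the
// curve sits at (1, 0, 0) and came from (0, 1, 0), so GetBackwardStep(2)
// must equal expected[1] - expected[2] = (-1, 1, 0).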

static_for<0, num_access - 1, 1>{}([&](auto i) {
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);

static_assert(idx_curr[I0] == expected[i][I0]);
static_assert(idx_curr[I1] == expected[i][I1]);
static_assert(idx_curr[I2] == expected[i][I2]);

constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i);
constexpr auto expected_step = expected[i + I1] - expected[i];
static_assert(forward_step[I0] == expected_step[I0]);
static_assert(forward_step[I1] == expected_step[I1]);
static_assert(forward_step[I2] == expected_step[I2]);
});
}
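The linear (non-snaked) `expected` order above is just a plain nested traversal whose nesting is given by DimAccessOrder: Sequence<2, 0, 1> makes dim 2 the outermost loop and dim 1 the innermost, which is why dim 1 toggles fastest. A standalone sketch reproducing that order for this test's lengths (plain C++, independent of the CK types):

#include <array>
#include <cstdio>

int main()
{
    // TensorLengths = {3, 2, 2}, ScalarsPerAccess = {1, 1, 1},
    // DimAccessOrder = {2, 0, 1}: loop dim 2 outermost, dim 1 innermost.
    constexpr std::array<int, 3> lengths{3, 2, 2};
    for(int d2 = 0; d2 < lengths[2]; ++d2)
        for(int d0 = 0; d0 < lengths[0]; ++d0)
            for(int d1 = 0; d1 < lengths[1]; ++d1)
                std::printf("(%d, %d, %d)\n", d0, d1, d2); // matches `expected`
    return 0;
}

The snake-curved variant tested below keeps the same nesting but reverses the traversal direction of inner dimensions on alternate passes, which keeps consecutive accesses close together instead of jumping back to the start of a row.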

void traverse_using_space_filling_curve_snakecurved()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};

using TensorLengths = Sequence<16, 10, 9>;
using DimAccessOrder = Sequence<2, 0, 1>;
using ScalarsPerAccess = Sequence<4, 2, 3>;
using SpaceFillingCurve =
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;

constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
make_tuple(0, 2, 0),