bf16A_Int8B with fastgelu/bias (#1264)

* changed the copy function to v7r2

* adding multi_abd

* in-progress

* add post-load oob check

* debugging

* adjust instances

* add run_lds

* add elemntwise_op

* replace multi_abd_device with v3

* clean up

* clean

* clean

* Added LDSType

* profiling

* adjust oobcheck

* add missing file

* refactor

* clean

* add examples

[ROCm/composable_kernel commit: 0d0150db20]
This commit is contained in:
zjing14
2024-04-26 07:26:30 -05:00
committed by GitHub
parent 672831fa3b
commit 1c712ea255
37 changed files with 4752 additions and 970 deletions

View File

@@ -10,4 +10,7 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
add_executable(client_gemm_bf16_i8_bf16 gemm_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_multiply_bf16_i8_bf16 gemm_xdl_multiply_bf16_i8.cpp)
target_link_libraries(client_gemm_multiply_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -38,19 +38,19 @@ using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -36,7 +36,7 @@ using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Col;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
@@ -45,12 +45,12 @@ using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -37,19 +37,19 @@ using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -74,12 +74,12 @@ struct SimpleDeviceMem
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
if(argc == 1)

View File

@@ -37,19 +37,19 @@ using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -0,0 +1,220 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using DsDataType = ck::Tuple<B1DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<B1Layout>;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Multiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// clang-format on
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 7)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideE = std::stoi(argv[6]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, Row>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
f_matrix_space_size(M, K, StrideA, A0Layout{}));
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
f_matrix_space_size(K, N, StrideB, B0Layout{}));
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 1;
constexpr ck::index_t NumDTensor = 1;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB},
std::array<ck::index_t, NumDTensor>{0},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}

View File

@@ -38,19 +38,19 @@ using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -36,7 +36,7 @@ using D0DataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Col;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
@@ -45,12 +45,12 @@ using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -52,12 +52,12 @@ using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

View File

@@ -1,2 +1,4 @@
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_bf16_i8 gemm_multi_ABD_xdl_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_fastgelu_bf16_i8 gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp)

View File

@@ -18,9 +18,12 @@
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
@@ -41,22 +44,22 @@ using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
@@ -64,9 +67,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
@@ -74,13 +77,13 @@ int main(int argc, char* argv[])
bool time_kernel = false;
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = N;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)

View File

@@ -0,0 +1,273 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Multiply;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -0,0 +1,274 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using DsDataType = ck::Tuple<B1DataType, D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<B1Layout, D0Layout>;
using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyAddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 4096;
ck::index_t N = 768;
ck::index_t K = 6144;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 1;
constexpr ck::index_t NumDTensor = 2;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer(),
d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB},
std::array<ck::index_t, NumDTensor>{0, StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
#if 0
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags>
typename ThreadTransferDstResetCoordinateAfterRunFlags,
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v7r2
{
static constexpr index_t nDim =
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template <typename SrcBuffers>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs);
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
}
}
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
SrcScalarPerVector,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadTransferDstResetCoordinateAfterRunFlags,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};

View File

@@ -92,15 +92,6 @@ struct Add
};
};
struct Scales
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
}
};
struct Max
{
template <typename Y, typename X0, typename X1>
@@ -188,6 +179,16 @@ struct Multiply
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
@@ -521,6 +522,71 @@ struct AddFastGelu
}
};
// E = MultiplyFastGelu(C + D)
struct MultiplyFastGelu
{
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
const float x = c * d;
FastGelu{}.template operator()<float, float>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c * d;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
const float x0_f = c * d;
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
{
const float x0_f = type_convert<float>(c) * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
const float x0_f = c * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
};
// E = Silu(C + D)
struct AddSilu
{

View File

@@ -221,6 +221,15 @@ struct MultiplyAdd
e = y;
}
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
const float& c,
const bhalf_t& d0,
const bhalf_t& d1) const
{
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
@@ -240,6 +249,26 @@ struct MultiplyAdd
}
};
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
{
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = ck::type_convert<ck::bhalf_t>(x1_f);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{

View File

@@ -504,6 +504,16 @@ struct FastGelu
y = type_convert<half_t>(y_f);
}
template <>
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<bhalf_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{

View File

@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
#if 0
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, false>,
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
#if 0
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, false>,
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),

File diff suppressed because it is too large Load Diff

View File

@@ -42,7 +42,8 @@ template <typename SrcDatas,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v7r2
{
static constexpr auto I0 = Number<0>{};
@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
bool oob_val = true;
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
oob_val = oob_val & is_src_valid;
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
});
constexpr auto get_elem_op_vec_len = []() {
@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs);
});
elm_vectors_tuple_(iAccess) = elm_vectors;
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
// move coordinate
if constexpr(iAccess.value != src_num_access - 1)
@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
}
__device__ void TransposeFromElmToDst()
#if 1
template <index_t ThreadScratchId = 0>
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
static_for<0, nDst, 1>{}([&](auto i) {
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
});
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
});
}
#endif
template <index_t ThreadScratchId = 0>
__device__ void
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch =
using ElmThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()),
true>;
using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2
decltype(GetDstThreadScratchDescriptor()),
true>;
SrcThreadScratch elm_thread_scratch_;
ElmThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
}
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
index_t ThreadScratchId = 0,
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
TransposeFromElmToDst();
OOBCheck(thread_scratch_id);
TransposeFromElmToDst(thread_scratch_id);
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
using ElmVectorTuple = StaticallyIndexedArray<ElmVectorsType, src_num_access>;
using DstVectorTuple = StaticallyIndexedArray<DstVectorsType, dst_num_access>;
StaticallyIndexedArray<ElmVectorTuple, NumThreadScratch> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorTuple, NumThreadScratch> dst_vectors_tuple_;
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;

View File

@@ -40,23 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
Y y;
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
// auto t = reinterpret_cast<const Y*>(&x);
// y = *t;
__builtin_memcpy(&y, &x, sizeof(X));
return y;
#else
union AsType
{
X x;
Y y;
};
return AsType{x}.y;
#endif
return __builtin_bit_cast(Y, x);
}
} // namespace ck

View File

@@ -91,23 +91,26 @@ using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
// pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
using Multiply = ck::tensor_operation::element_wise::Multiply;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;

View File

@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
@@ -33,7 +33,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
@@ -46,7 +46,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
@@ -59,7 +59,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
@@ -72,7 +72,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// RCR
@@ -86,7 +86,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
@@ -99,7 +99,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
@@ -112,7 +112,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
@@ -125,7 +125,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// CRR
@@ -139,7 +139,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
@@ -152,7 +152,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
@@ -165,7 +165,7 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
@@ -178,8 +178,62 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// Multiply
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances);
#endif
// GEMM + Add + Gelu
@@ -201,7 +255,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
@@ -213,7 +267,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>;
static auto GetInstances()
@@ -271,7 +325,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
@@ -283,7 +337,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>;
static auto GetInstances()
@@ -341,7 +395,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
@@ -353,7 +407,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>;
static auto GetInstances()
@@ -411,7 +465,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
@@ -423,7 +477,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>;
static auto GetInstances()
@@ -462,6 +516,234 @@ struct DeviceOperationInstanceFactory<
}
};
// Multiply
// GEMM + Add + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16, BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row, Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Add
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16, BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row, Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
MultiplyFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
Multiply>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
Multiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -17,7 +17,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
@@ -32,7 +32,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
@@ -45,7 +45,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
@@ -58,7 +58,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
@@ -71,7 +71,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// RCR
@@ -85,7 +85,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
@@ -98,7 +98,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
@@ -111,7 +111,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
@@ -124,7 +124,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// CRR
@@ -138,7 +138,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_g
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
@@ -151,7 +151,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_i
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Multiply,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
@@ -164,7 +164,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_i
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
@@ -177,7 +177,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
ck::Tuple<>,
BF16,
PassThrough,
Scales,
Multiply,
PassThrough>>>& instances);
// GEMM + Add + Gelu
@@ -199,7 +199,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
@@ -211,7 +211,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
AddFastGelu>;
static auto GetInstances()
@@ -270,7 +270,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
@@ -282,7 +282,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
Add>;
static auto GetInstances()
@@ -341,7 +341,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
@@ -353,7 +353,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
FastGelu>;
static auto GetInstances()
@@ -412,7 +412,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
@@ -424,7 +424,7 @@ struct DeviceOperationInstanceFactory<
DsDataType,
EDataType,
PassThrough,
Scales,
Multiply,
PassThrough>;
static auto GetInstances()

View File

@@ -4,7 +4,8 @@ set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
)
add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})

View File

@@ -1,101 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -27,13 +27,13 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
// using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using CShuffleDataType = F32;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
@@ -42,57 +42,84 @@ using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
// using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
template <typename BsLayout,
typename DsLayout,
typename BsDataType,
typename DsDataType,
typename BElementOp,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename BsLayout,
typename DsLayout,
typename BsDataType,
typename DsDataType,
typename BElementOp,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance

View File

@@ -47,14 +47,18 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -66,33 +70,52 @@ template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
//Compute-bound
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 128, 8, 16, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 64, 128, 8, 16, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 128, 8, 16, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 32, 256, 128, 8, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance

View File

@@ -1,115 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -19,94 +19,143 @@ namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Multiply,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Multiply,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<D0Layout>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<D0DataType>,
Multiply,
Add,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
Multiply,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
Multiply,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout, B1Layout>,
ck::Tuple<>,
ck::Tuple<B0DataType, B1DataType>,
ck::Tuple<>,
Multiply,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance

View File

@@ -32,12 +32,11 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
@@ -55,12 +54,11 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
@@ -78,12 +76,11 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
@@ -101,12 +98,11 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance

View File

@@ -0,0 +1,163 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAddFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAddFastGelu,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
ck::Tuple<B0Layout>,
ck::Tuple<D0Layout, B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<D0DataType, B1DataType>,
PassThrough,
MultiplyAdd,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
Multiply,
GemmMNKPadding,
Interwave>{});
}
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ELayout,
AsDataType,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
EDataType,
AElementOp,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
ck::Tuple<B1Layout>,
ck::Tuple<B0DataType>,
ck::Tuple<B1DataType>,
PassThrough,
MultiplyFastGelu,
GemmMNKPadding,
Interwave>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

View File

@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

View File

@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;