external api for gemm + layernorm (#285)

* Extract base class for elementwise

* Refactor interface of DeviceGemmReduce. Do not use tuple in interface

* [What] Rename d into reduce in gemm + reduction related code
[Why] Prepare to add d term for add

* Unify base class of gemm + reduce and gemm + bias + add + reduce

* 1. Rename gemm_bias_add_reduce for external api
 2. Refine cmake

* Add normalize device operation

* [What] Reorder the argument
[Why] Because d0 is also the input of c.

* Add type string

* Add example of gemm_bias_add_layernorm  via external api

* Refactor example code

* clang-format

* Fix compile error

* clang-format

* Add external api for gemm_add_add_layernorm and normalize

* Add client example

* clang-format
This commit is contained in:
rocking5566
2022-06-28 03:25:10 +08:00
committed by GitHub
parent aebd211c36
commit 12235112a1
47 changed files with 2577 additions and 1946 deletions

View File

@@ -0,0 +1,2 @@
add_executable(gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
target_link_libraries(gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)

View File

@@ -0,0 +1,270 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using BiasDataType = F32;
using CDataType = F16;
using D0DataType = F16;
using ReduceDataType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <typename gemm_reduce_op_ptr>
bool RunDeviceGemmMeanSquareMean(gemm_reduce_op_ptr& p_op,
const void* p_a,
const void* p_b,
const void* p_bias,
const void* p_d0,
void* p_c,
void* p_mean,
void* p_square_mean,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int StrideD0,
bool time_kernel)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
auto passOp = PassThrough{};
auto squareOp = UnarySquareElementOp{};
auto divOp = UnaryDivElementOp{N};
auto argument_ptr =
p_op->MakeArgumentPointer(p_a,
p_b,
p_bias,
{p_d0},
p_c,
{p_mean, p_square_mean},
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
{&passOp, &passOp, &passOp}, // functor for a, b, c
{&passOp}, // functor for d0
{&passOp, &squareOp}, // functor for inputs of reduction
{&divOp, &divOp}); // functor for outputs of reduction
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
// If we evaluate running time of gemm_reduce. The output may wrong.
// Because we need to initialize the reduction tensor before runing the kernel.
// However we run kernel many times for time_kernel = trie without reinitialize the out
// of reduction tensor.
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Gemm + reduce Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
template <typename normalize_op_ptr>
bool RunDeviceNormalize2D(normalize_op_ptr& p_op,
const void* p_x,
const void* p_mean,
const void* p_square_mean,
const void* p_gamma,
const void* p_beta,
void* p_y,
int M,
int N,
int StrideX,
bool time_kernel)
{
std::array<const void*, 5> input = {p_x, p_mean, p_square_mean, p_gamma, p_beta};
std::array<void*, 1> output = {p_y};
auto normalize_functor = ck::tensor_operation::element_wise::Normalize{};
auto argument_ptr = p_op->MakeArgumentPointer(input,
output,
{M, N},
{{StrideX, 1}, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{{StrideX, 1}},
ck::tensor_operation::element_wise::Normalize{});
if(p_op->IsSupportedArgument(argument_ptr.get()))
{
auto invoker_ptr = p_op->MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
std::cout << "Normalize Perf: " << std::setw(10) << ave_time << " ms" << std::endl;
return true;
}
return false;
}
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideD0 = 1024;
const auto gemm_reduce_ptrs = ck::tensor_operation::device::device_gemm_instance::
get_device_gemm_add_add_mean_squaremean_instances<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();
const auto normalize_ptrs =
ck::tensor_operation::device::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
std::cout << "found " << gemm_reduce_ptrs.size()
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
SimpleDeviceMem bias_device_buf(sizeof(BiasDataType) * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD0, CLayout{}));
SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
bool b_time_kernel = true;
bool b_only_run_first_kernel = true;
// layernorm => (1) + (2)
// (1). c = gemm(a, b), reduce_mean(c), reduce_square_mean(c)
// (2). normalize(c, mean, square_mean, gamma, beta)
for(auto& gemm_reduce_ptr : gemm_reduce_ptrs)
{
// run first available kernel
if(RunDeviceGemmMeanSquareMean(gemm_reduce_ptr,
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
d0_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideD0,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << gemm_reduce_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
for(auto& normalize_ptr : normalize_ptrs)
{
if(RunDeviceNormalize2D(normalize_ptr,
c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
layerNorm_device_buf.GetDeviceBuffer(),
M,
N,
StrideC,
b_time_kernel))
{
if(b_only_run_first_kernel)
break;
}
else
{
std::cout << normalize_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}

View File

@@ -7,3 +7,4 @@ find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(02_gemm_add_add_fastgelu)
add_subdirectory(03_gemm_layernorm)

View File

@@ -33,19 +33,19 @@ using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F64;
using DPtrsGlobal = ck::Tuple<DDataType*>;
using ReduceDataType = F64;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using DsReduceOp = ck::Tuple<ck::reduce::Max>;
using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
using DGlobalMemOp =
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceOps = ck::Tuple<ck::reduce::Max>;
using ReduceElementOps = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmSpecialization =
@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M;
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
@@ -148,17 +148,17 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_host_result(
Tensor<ReduceDataType> reduce_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d_m_device_result(
Tensor<ReduceDataType> reduce_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d_m: " << d_m_host_result.mDesc << std::endl;
std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -176,35 +176,40 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
reduce_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto ds_element_op = DsElementOp{};
auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}];
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
std::array<void*, 1> reduce_element_ops = {&reduce_element_op};
std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
p_ds_global,
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
ds_element_op,
ds_element_op);
{},
gemm_element_ops,
{},
reduce_element_ops,
reduce_element_ops);
if(!gemm.IsSupportedArgument(argument))
{
@@ -215,7 +220,7 @@ int main(int argc, char* argv[])
// [CAUSION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
@@ -223,7 +228,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d_device_buf.FromDevice(d_m_device_result.mData.data());
reduce_device_buf.FromDevice(reduce_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -233,27 +238,27 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}];
auto reduce_op = ReduceOps{}[ck::Number<0>{}];
for(int m = 0; m < M; ++m)
{
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType curr_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
d_reduce_op(d_acc, curr_val);
reduce_op(reduce_acc, curr_val);
};
d_m_host_result(m) = d_acc;
reduce_m_host_result(m) = reduce_acc;
}
pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(d_m_device_result.mData,
d_m_host_result.mData,
ck::utils::check_err(reduce_m_device_result.mData,
reduce_m_host_result.mData,
"Error: Incorrect results d",
1e-3,
1e-3);
@@ -263,7 +268,7 @@ int main(int argc, char* argv[])
{
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
gemm_reduceMax_ave_time, M, N, K);
}

View File

@@ -33,27 +33,27 @@ using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp =
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
@@ -158,22 +158,22 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result(
Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result(
Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result(
Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result(
Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -191,39 +191,48 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
dxs_global,
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
if(!gemm.IsSupportedArgument(argument))
{
@@ -232,9 +241,9 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
// init reducetion buffer to 0
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
@@ -244,8 +253,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -255,42 +264,40 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
auto reduce0_op = ReduceOp0{};
auto reduce1_op = ReduceOp1{};
for(int m = 0; m < M; ++m)
{
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
ReduceAccDataType square_c_val;
square(square_c_val, c_val);
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
reduce0_op(reduce0_acc, c_val);
reduce1_op(reduce1_acc, square_c_val);
}
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(d0_m_device_result.mData,
d0_m_host_result.mData,
ck::utils::check_err(reduce0_m_device_result.mData,
reduce0_m_host_result.mData,
"Error: Incorrect results d0",
1e-4,
1e-5) &&
ck::utils::check_err(d1_m_device_result.mData,
d1_m_host_result.mData,
ck::utils::check_err(reduce1_m_device_result.mData,
reduce1_m_host_result.mData,
"Error: Incorrect results d1",
1e-3,
1e-5);
@@ -300,7 +307,7 @@ int main(int argc, char* argv[])
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K);
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;

View File

@@ -31,26 +31,26 @@ using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DGlobalMemOp =
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
@@ -143,16 +143,16 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
@@ -177,38 +177,48 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM
auto batched_gemm = DeviceBatchedGemmReduceInstance{};
auto invoker = batched_gemm.MakeInvoker();
auto argument =
batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
DxsInElementOps{},
DxsOutElementOps{},
BatchCount);
auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops,
BatchCount);
if(!batched_gemm.IsSupportedArgument(argument))
{
@@ -218,8 +228,8 @@ int main(int argc, char* argv[])
}
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
@@ -241,8 +251,8 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
@@ -252,15 +262,15 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
auto d0_reduce_op = D0ReduceOp{};
auto d1_reduce_op = D1ReduceOp{};
auto reduce0_op = ReduceOp0{};
auto reduce1_op = ReduceOp1{};
for(int batch = 0; batch < BatchCount; ++batch)
{
for(int m = 0; m < M; ++m)
{
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
@@ -271,12 +281,12 @@ int main(int argc, char* argv[])
UnaryIdenticElementOp{}(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
reduce0_op(reduce0_acc, d0_val);
reduce1_op(reduce1_acc, d1_val);
}
d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
}

View File

@@ -99,15 +99,17 @@ int main()
a_m_n_device_buf.ToDevice(a_m_n.mData.data());
b_n_device_buf.ToDevice(b_n.mData.data());
std::array<const void*, 2> input = {a_m_n_device_buf.GetDeviceBuffer(),
b_n_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {Stride, 1};
std::vector<ck::index_t> b_strides = {0, 1};
std::vector<ck::index_t> c_strides = {Stride, 1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(),
b_n_device_buf.GetDeviceBuffer(),
c_m_n_device_buf.GetDeviceBuffer(),
{M, N},
{Stride, 1},
{0, 1}, // broadcast in first dimension
{Stride, 1},
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{

View File

@@ -81,18 +81,24 @@ int main()
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1, 0, 0};
std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()};
std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_m_device_buf.GetDeviceBuffer(),
b_m_n_k_device_buf.GetDeviceBuffer(),
c_m_n_k_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{1, 0, 0}, // broadcast A on second and third dimension
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()},
Add{});
auto argument =
broadcastAdd.MakeArgumentPointer(input,
output,
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{a_strides, b_strides},
{c_strides},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{

View File

@@ -79,15 +79,17 @@ int main()
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data());
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
b_m_device_buf.GetDeviceBuffer(),
c_m_device_buf.GetDeviceBuffer(),
{M},
{1},
{1},
{1},
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{

View File

@@ -81,16 +81,22 @@ int main()
a_device_buf.ToDevice(a.mData.data());
b_device_buf.ToDevice(b.mData.data());
std::array<const void*, 2> input = {a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
Add{});
auto argument =
broadcastAdd.MakeArgumentPointer(input,
output,
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
{{a_strides}, b_strides},
{c_strides},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{

View File

@@ -31,12 +31,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using C0DataType = F32;
using C1DataType = F16;
using BiasDataType = F32;
using D0DataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
@@ -50,17 +50,17 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu;
using C1ElementOp = PassThrough;
using D0ElementOp = PassThrough;
using ReduceSumOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsGlobalMemOp =
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -69,11 +69,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -89,8 +89,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
DDataType,
DDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
@@ -125,10 +125,10 @@ auto f_host_tensor_descriptor2d =
};
template <typename CDataType,
typename DDataType,
typename ReduceDataType,
typename AccDataType,
typename C0DataType,
typename C1DataType,
typename BiasDataType,
typename D0DataType,
typename A_functor,
typename B_functor,
typename C_functor,
@@ -136,8 +136,8 @@ template <typename CDataType,
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n,
const Tensor<C0DataType>& bias_n,
const Tensor<C1DataType>& c1_m_n,
const Tensor<BiasDataType>& bias_n,
const Tensor<D0DataType>& c1_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n,
A_functor a_element_op,
@@ -150,8 +150,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{};
@@ -196,8 +196,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
}
// LayerNorm
@@ -213,7 +213,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
static_cast<AccDataType>(meanSquare_m(m)),
static_cast<AccDataType>(gamma_n(n)),
static_cast<AccDataType>(beta_n(n)));
out_m_n(m, n) = static_cast<DDataType>(out_acc);
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
}
}
}
@@ -221,9 +221,9 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
template <typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename DDataType,
typename BiasDataType,
typename D0DataType,
typename ReduceDataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
@@ -231,12 +231,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
@@ -260,15 +260,15 @@ int main()
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideC1 = 1024;
ck::index_t StrideD0 = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
@@ -276,18 +276,18 @@ int main()
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
@@ -297,44 +297,45 @@ int main()
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
c1_device_buf.ToDevice(c1_m_n.mData.data());
d0_device_buf.ToDevice(c1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto c1_element_op = C1ElementOp{};
auto dxs_global =
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d_element_op = D0ElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument =
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
gemm_element_ops,
{&d_element_op},
reduce_in_element_ops,
reduce_out_element_ops);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{
@@ -347,23 +348,25 @@ int main()
reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
auto normalize_argument = normalize.MakeArgument(input,
output,
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
{
@@ -381,19 +384,19 @@ int main()
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
M,
N);
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
d_element_op,
M,
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData,
@@ -416,9 +419,9 @@ int main()
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
DDataType,
BiasDataType,
D0DataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(

View File

@@ -33,8 +33,8 @@ using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
@@ -48,15 +48,15 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceSumOp = ck::reduce::Add;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsGlobalMemOp =
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -65,11 +65,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -85,8 +85,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
DDataType,
DDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
@@ -121,7 +121,7 @@ auto f_host_tensor_descriptor2d =
};
template <typename CDataType,
typename DDataType,
typename ReduceDataType,
typename A_functor,
typename B_functor,
typename C_functor>
@@ -140,8 +140,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{};
@@ -172,8 +172,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
}
// LayerNorm
@@ -197,7 +197,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename ReduceDataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
@@ -205,11 +205,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
@@ -237,8 +237,8 @@ int main()
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
@@ -252,8 +252,8 @@ int main()
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
@@ -265,35 +265,40 @@ int main()
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto dxs_global =
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument =
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
auto gemmReduce = DeviceGemmReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{
@@ -306,23 +311,25 @@ int main()
reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
auto normalize_argument = normalize.MakeArgument(input,
output,
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
{
@@ -340,16 +347,16 @@ int main()
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
host_gemm_layernorm<CDataType, DDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
M,
N);
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
M,
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData,
@@ -372,7 +379,7 @@ int main()
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
DDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(

View File

@@ -10,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
@@ -35,7 +35,7 @@ template <typename ADataType,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Device5AryElementwise : public BaseOperator
struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
@@ -268,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return true;
};
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const CDataType* p_c,
const DDataType* p_d,
const EDataType* p_e,
FDataType* p_f,
static auto MakeArgument(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
@@ -283,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
{
return Argument{p_a,
p_b,
p_c,
p_d,
p_e,
p_f,
return Argument{static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
@@ -299,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
functor};
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_c,
const void* p_d,
const void* p_e,
void* p_f,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
std::vector<index_t> d_strides,
std::vector<index_t> e_strides,
std::vector<index_t> f_strides,
ElementwiseFunctor functor)
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 5> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const CDataType*>(p_c),
static_cast<const DDataType*>(p_d),
static_cast<const EDataType*>(p_e),
static_cast<FDataType*>(p_f),
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<const CDataType*>(p_inputs[2]),
static_cast<const DDataType*>(p_inputs[3]),
static_cast<const EDataType*>(p_inputs[4]),
static_cast<FDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
c_strides,
d_strides,
e_strides,
f_strides,
input_strides[0],
input_strides[1],
input_strides[2],
input_strides[3],
input_strides[4],
output_strides[0],
functor);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Device5aryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< "DScalarPerVector = " << DScalarPerVector
<< "EScalarPerVector = " << EScalarPerVector
<< "FScalarPerVector = " << FScalarPerVector
<< ">";
// clang-format on
return str.str();
}
}; // namespace device
} // namespace device
} // namespace tensor_operation

View File

@@ -1,54 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceBatchedGemmReducePtr =
std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
#include "ck/device_utility/device_prop.hpp"
@@ -23,16 +23,16 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename ComputeBasePrtOfBatch,
typename Block2CTileMap,
bool HasMainK0BlockLoop>
@@ -44,18 +44,18 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
const Block2CTileMap block_2_ctile_map)
{
@@ -71,10 +71,10 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
});
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -82,33 +82,33 @@ __global__ void
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_ds_grid,
p_reduces_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
reduce_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_ds_grid;
ignore = p_reduces_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = compute_base_ptr_of_batch_;
ignore = block_2_ctile_map;
#endif
@@ -126,14 +126,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -168,12 +168,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle
: public DeviceBatchedGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -446,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -474,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
struct ComputeBasePtrOfStridedBatch
{
@@ -527,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
CShuffleDataType,
CDataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -582,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -592,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops,
index_t Batch)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid},
p_reduces_grid_{p_reduces_grid},
Batch_(Batch),
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
compute_base_ptr_of_batch_{
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -627,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
@@ -636,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_;
ReducePtrsGlobal p_reduces_grid_;
index_t Batch_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
@@ -678,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
@@ -704,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
@@ -727,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
@@ -747,16 +743,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
@@ -770,17 +766,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.Batch_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.compute_base_ptr_of_batch_,
arg.block_2_ctile_map_);
}
@@ -824,38 +820,76 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
}
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch)
{
return Argument{p_a,
p_b,
p_c,
p_dxs,
MRaw,
NRaw,
KRaw,
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
Batch};
}
@@ -865,37 +899,73 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t Batch) override
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t Batch = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
dxs_tuple,
MRaw,
NRaw,
KRaw,
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
Batch);
}

View File

@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck {
@@ -25,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
static constexpr auto I0 = Number<0>{};
@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<index_t> lengths,
std::vector<index_t> a_strides,
std::vector<index_t> b_strides,
std::vector<index_t> c_strides,
ElementwiseFunctor functor)
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, 2> p_inputs,
std::array<void*, 1> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_outputs[0]),
lengths,
a_strides,
b_strides,
c_strides,
input_strides[0],
input_strides[1],
output_strides[0],
functor);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str << "DeviceBinaryElementwise"
<< "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">";
// clang-format on

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
using DeviceElementwisePtr =
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -29,20 +29,20 @@ template <typename ALayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename BiasDataType,
typename D0DataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -77,13 +77,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle
: public DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
{
using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
@@ -356,7 +350,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -386,7 +380,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
@@ -394,25 +388,25 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GemmAccDataType,
CShuffleDataType,
CDataType,
C0DataType,
C1DataType,
BiasDataType,
D0DataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -455,9 +449,9 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const C0DataType* p_c0_grid,
const C1DataType* p_c1_grid,
DPtrsGlobal p_ds_grid,
const BiasDataType* p_bias_grid,
const D0DataType* p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -468,32 +462,32 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
D0ElementwiseOperation d0_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
p_ds_grid_{p_ds_grid},
p_bias_grid_{p_bias_grid},
p_d0_grid_{p_d0_grid},
p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
c1_element_op_{c1_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
d0_element_op_{d0_element_op},
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -512,8 +506,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
@@ -521,29 +515,30 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
DPtrsGlobal p_ds_grid_;
const BiasDataType* p_bias_grid_;
const D0DataType* p_d0_grid_;
ReducePtrsGlobal p_reduces_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
D0ElementwiseOperation d0_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
@@ -574,21 +569,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
@@ -601,21 +596,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.d0_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
@@ -624,21 +619,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
@@ -651,21 +646,21 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.d0_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
@@ -700,45 +695,76 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const C0DataType* p_c0,
const C1DataType* p_c1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_c0,
p_c1,
p_dxs,
MRaw,
NRaw,
KRaw,
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op};
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -747,45 +773,74 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t /* KBatch */ = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0),
static_cast<const C1DataType*>(p_c1),
dxs_tuple,
MRaw,
NRaw,
KRaw,
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic
@@ -800,7 +855,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmReduce_Xdl_CShuffle"
str << "DeviceGemmBiasAddReduce_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "

View File

@@ -9,91 +9,34 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
template <ck::index_t NumDTensor, ck::index_t NumReduce>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
void* p_dxs,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
std::array<ck::index_t, NumDTensor> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, NumDTensor> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_ops,
std::array<void*, NumReduce> reduce_out_element_ops,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <ck::index_t NumDTensor, ck::index_t NumReduce>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -32,14 +32,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -74,11 +74,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -350,8 +346,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
}
}
// assume D is packed tensor
static auto MakeDGridDescriptor_M(index_t MRaw)
// assume Reduce is packed tensor
static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
@@ -379,7 +375,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
@@ -388,19 +384,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType,
CDataType,
ReduceAccDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
ReduceGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
DGridDesc_M,
ReduceGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -443,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -453,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_ds_grid_{p_ds_grid},
p_reduces_grid_{p_reduces_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -481,8 +477,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
@@ -490,20 +486,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DPtrsGlobal p_ds_grid_;
ReducePtrsGlobal p_reduces_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
ReduceGridDesc_M reduce_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
@@ -528,7 +525,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
@@ -554,16 +551,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
@@ -576,16 +573,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
@@ -594,16 +591,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DPtrsGlobal,
ReducePtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
@@ -616,16 +613,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_ds_grid_,
arg.p_reduces_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.reduce_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
@@ -660,37 +657,75 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_dxs,
MRaw,
NRaw,
KRaw,
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op};
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -699,37 +734,73 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
ck::index_t = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
dxs_tuple,
MRaw,
NRaw,
KRaw,
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic

View File

@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
@@ -46,15 +46,15 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -63,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -72,42 +72,42 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
p_ds_grid,
p_bias_grid,
p_d0_grid,
p_reduces_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
reduce_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = p_ds_grid;
ignore = p_bias_grid;
ignore = p_d0_grid;
ignore = p_reduces_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -119,22 +119,22 @@ template <typename FloatAB,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename ReduceGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename DGridDesc_M,
typename ReduceGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock;
return reduce_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
@@ -352,36 +352,37 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const ReduceInElementwiseOperations& reduce_in_element_ops,
const ReduceAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -390,9 +391,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
@@ -725,12 +726,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
// VGPR reduce_thread_desc_mperblock
constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
// VGPR reduce_thread_desc_mblock_mperblock
constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
@@ -759,29 +760,29 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];
auto p_reduce_grid = p_reduces_grid[I];
auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
ReduceGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
reduce_acc_element_op};
},
Number<p_ds_grid.Size()>{});
Number<p_reduces_grid.Size()>{});
// c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
@@ -909,35 +910,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In];
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d_thread_buf =
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In];
auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& d_reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In);
auto& reduce_thread_copy_vgpr_to_global =
reduce_tuple_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
DReduceOperation,
decltype(reduce_thread_desc_mperblock),
ReduceOperation,
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
const auto reduce_identityVal =
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
@@ -946,26 +947,25 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run(
d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d_thread_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
reduce_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});

View File

@@ -21,16 +21,16 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
@@ -41,17 +41,17 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid,
ReducePtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -60,32 +60,32 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_reduces_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
reduce_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_ds_grid;
ignore = p_reduces_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -95,19 +95,19 @@ template <typename FloatAB,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc,
typename DPtrsGlobal,
typename ReducePtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename ReduceGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename DGridDesc_M,
typename ReduceGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return d_grid_desc_mblock_mperblock;
return reduce_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
@@ -318,29 +318,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
ReducePtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const ReduceInElementwiseOperations& reduce_in_element_ops,
const ReduceAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -706,12 +707,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
// VGPR reduce_thread_desc_mperblock
constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
// VGPR reduce_thread_desc_mblock_mperblock
constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
@@ -740,29 +741,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];
auto p_reduce_grid = p_reduces_grid[I];
auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
ReduceGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
reduce_acc_element_op};
},
Number<p_ds_grid.Size()>{});
Number<p_reduces_grid.Size()>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -797,35 +798,35 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0),
c_reduce_thread_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In];
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d_thread_buf =
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In];
auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& d_reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In);
auto& reduce_thread_copy_vgpr_to_global =
reduce_tuple_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
DReduceOperation,
decltype(reduce_thread_desc_mperblock),
ReduceOperation,
false>;
// Global write Gemm shuffle + reduction
const auto d_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
const auto reduce_identityVal =
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; });
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
@@ -834,26 +835,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run(
d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d_thread_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
reduce_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using Normalize = ck::tensor_operation::element_wise::Normalize;
using DeviceNormalizeFromMeanMeanSquarePtr =
ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);
template <typename InputType,
typename MeanType,
typename MeanSquareType,
typename GammaDataType,
typename BetaDataType,
typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances()
{
std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;
if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
{
ck::tensor_operation::device::
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
}
return op_ptrs;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_gemm_add_add_mean_squaremean_instances()
{
std::vector<DeviceGemmAddAddMeanSquareMeanPtr> op_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -5,6 +5,7 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(elementwise)
add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d)
@@ -31,6 +32,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
@@ -44,6 +46,8 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_reduce_instance>
)
add_library(composablekernels::device_operations ALIAS device_operations)

View File

@@ -15,9 +15,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -29,10 +29,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -43,35 +43,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmReducePtr<PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances,

View File

@@ -15,9 +15,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -29,10 +29,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -43,35 +43,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmReducePtr<PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances,

View File

@@ -15,9 +15,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -29,10 +29,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -43,35 +43,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmReducePtr<PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances,

View File

@@ -15,9 +15,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -29,10 +29,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -43,32 +43,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
// clang-format on
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmReducePtr<PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances,

View File

@@ -0,0 +1,10 @@
set(DEVICE_ELEMENTWISE_INSTANCE_SOURCE
device_normalize_instance.cpp
)
add_instance_library(device_elementwise_instance ${DEVICE_ELEMENTWISE_INSTANCE_SOURCE})
target_compile_features(device_elementwise_instance PUBLIC)
set_target_properties(device_elementwise_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_elementwise_instance)

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using F16 = ck::half_t;
using F32 = float;
using inputType = F16;
using MeanType = F32;
using SquareMeanType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using outputType = F16;
using Normalize = ck::tensor_operation::element_wise::Normalize;
using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple<
// clang-format off
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 8, 8, 1, 1, 8, 8, 8 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 4, 4, 1, 1, 4, 4, 4 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 2, 2, 1, 1, 2, 2, 2 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 1, 1, 1, 1, 1, 1, 1 >
// clang-format on
>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceElementwisePtr<5, 1, 2, Normalize>>& instances)
{
add_device_operation_instances(
instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{});
}
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,10 +1,13 @@
set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
set(DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)
add_instance_library(device_gemm_bias_add_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE})
rocm_install(TARGETS device_gemm_bias_add_reduce_instance)
add_library(device_gemm_bias_add_reduce_instance OBJECT ${DEVICE_GEMM_BIAS_ADD_REDUCE_INSTANCE_SOURCE})
target_compile_features(device_gemm_bias_add_reduce_instance PUBLIC)
set_target_properties(device_gemm_bias_add_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_gemm_bias_add_reduce_instance)

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[n, k]
using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,87 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,87 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[n, k]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,87 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,84 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -16,9 +16,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -30,11 +30,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -44,33 +44,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{});

View File

@@ -16,9 +16,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -30,11 +30,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -44,33 +44,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[k, m] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple<
// clang-format off
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{});

View File

@@ -16,9 +16,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -30,11 +30,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -44,33 +44,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple<
// clang-format off
//###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{});

View File

@@ -16,9 +16,9 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -30,11 +30,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
@@ -44,30 +44,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple<
// clang-format off
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
// clang-format on
>;
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
std::vector<DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>>& instances)
{
add_device_operation_instances(
instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{});

View File

@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -21,32 +21,28 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using F32 = float;
using F16 = ck::half_t;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
using DeviceBatchedGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmReducePtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
DInElementOps,
DOutElementOps>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmReduceNoOpPtr>&);
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmReduceNoOpPtr>&);
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmReduceNoOpPtr>&);
std::vector<DeviceGemmReduceNoOpPtr>&);
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmReduceNoOpPtr>&);
std::vector<DeviceGemmReduceNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
@@ -59,7 +55,7 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename ReduceDataType,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -99,16 +95,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
@@ -135,20 +131,23 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto dxs_in_element_op = DxsInElementOps{};
const auto dxs_out_element_op = DxsOutElementOps{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
const auto reduce0_op = ReduceOp0{};
const auto reduce1_op = ReduceOp1{};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
if(do_verification)
{
@@ -160,6 +159,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
BElementOp,
CElementOp>;
using ReduceAccDataType = ReduceDataType;
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
@@ -172,21 +173,22 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
{
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetIdentityValue<float>();
float d1_acc = d1_reduce_op.GetIdentityValue<float>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
float d0_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
float d1_val;
ReduceAccDataType d0_val =
ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
ReduceAccDataType d1_val;
UnarySquareElementOp{}(d1_val, d0_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
square(d1_val, d0_val);
reduce0_op(reduce0_acc, d0_val);
reduce1_op(reduce1_acc, d1_val);
}
d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
}
}
@@ -194,17 +196,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceBatchedGemmReduceNoOpPtr>
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
@@ -257,31 +261,32 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
&dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op,
BatchCount);
auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops,
BatchCount);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -311,8 +316,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result);
float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result);

View File

@@ -21,33 +21,28 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using F32 = float;
using F16 = ck::half_t;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
DInElementOps,
DOutElementOps>;
using DeviceGemmBiasAddReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace device_gemm_instance
@@ -61,9 +56,9 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename DDataType,
typename BiasDataType,
typename D0DataType,
typename ReduceDataType,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -77,7 +72,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
int StrideC1)
int StrideD0)
{
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
@@ -102,24 +97,24 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result(
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result(
Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result(
Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result(
Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
std::size_t num_thread = 1;
switch(init_method)
@@ -130,50 +125,53 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
d0_m_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
std::srand(0);
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
d0_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
using C1ElementOp = PassThrough;
using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add;
using D0ElementOp = PassThrough;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto c1_element_op = C1ElementOp{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
auto d0_element_op = D0ElementOp{};
const auto reduce0_op = ReduceOp0{};
const auto reduce1_op = ReduceOp1{};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
DDataType,
ReduceDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReduceAccDataType = DDataType;
using ReduceAccDataType = ReduceDataType;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -189,53 +187,53 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
ReduceAccDataType acc = static_cast<ReduceAccDataType>(c_m_n_host_result(m, n)) +
static_cast<ReduceAccDataType>(bias_n(n));
ReduceAccDataType c1 = static_cast<ReduceAccDataType>(c1_m_n(m, n));
ReduceAccDataType d0 = static_cast<ReduceAccDataType>(d0_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
d0_element_op(d0, d0);
acc += d0;
c_m_n_host_result(m, n) = static_cast<CDataType>(acc);
}
for(int m = 0; m < M; ++m)
{
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType c_val =
ReduceAccDataType d0_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
square(d1_val, d0_val);
reduce0_op(reduce0_acc, d0_val);
reduce1_op(reduce1_acc, d1_val);
}
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
c1_device_buf.ToDevice(c1_m_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr>
@@ -249,7 +247,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
@@ -257,7 +255,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
@@ -265,7 +263,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
@@ -273,7 +271,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
}
}
@@ -291,34 +289,31 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
&dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
gemm_element_ops,
{&d0_element_op},
reduce_in_element_ops,
reduce_out_element_ops);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -328,9 +323,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -350,12 +345,12 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);
ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData);
ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData);
if(do_log)
{
@@ -365,13 +360,17 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_host: ", d0_m_host_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d0_host: ", reduce0_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_device: ", d0_m_device_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d0_device: ", reduce0_m_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_host: ", d1_m_host_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d1_host: ", reduce1_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_device: ", d1_m_device_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d1_device: ", reduce1_m_device_result.mData, ",")
<< std::endl;
}
}

View File

@@ -21,21 +21,17 @@ namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using F32 = float;
using F16 = ck::half_t;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
DInElementOps,
DOutElementOps>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -60,7 +56,7 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename ReduceDataType,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -95,22 +91,22 @@ bool profile_gemm_reduce_impl(int do_verification,
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result(
Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result(
Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result(
Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result(
Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
std::size_t num_thread = 1;
switch(init_method)
@@ -130,34 +126,37 @@ bool profile_gemm_reduce_impl(int do_verification,
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
const auto reduce0_op = ReduceOp0{};
const auto reduce1_op = ReduceOp1{};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
DDataType,
ReduceDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReduceAccDataType = DDataType;
using ReduceAccDataType = ReduceDataType;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
@@ -169,37 +168,37 @@ bool profile_gemm_reduce_impl(int do_verification,
for(int m = 0; m < M; ++m)
{
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType c_val =
ReduceAccDataType d0_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
square(d1_val, d0_val);
reduce0_op(reduce0_acc, d0_val);
reduce1_op(reduce1_acc, d1_val);
}
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
@@ -258,30 +257,31 @@ bool profile_gemm_reduce_impl(int do_verification,
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
&dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// init DO, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -311,12 +311,12 @@ bool profile_gemm_reduce_impl(int do_verification,
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);
ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData);
ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData);
if(do_log)
{
@@ -326,13 +326,17 @@ bool profile_gemm_reduce_impl(int do_verification,
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_host: ", d0_m_host_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d0_host: ", reduce0_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_device: ", d0_m_device_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d0_device: ", reduce0_m_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_host: ", d1_m_host_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d1_host: ", reduce1_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_device: ", d1_m_device_result.mData, ",")
LogRangeAsType<float>(
std::cout << "d1_device: ", reduce1_m_device_result.mData, ",")
<< std::endl;
}
}