now can build

This commit is contained in:
carlushuang
2024-03-04 20:45:51 +00:00
parent 112d521b09
commit a67473fff8
55 changed files with 829 additions and 534 deletions

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename BinaryElementOp = ck_tile::plus<AccDataType>>
void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
{
const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename CDataType, typename MaskingType>
void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
const int M = c_b_m_n.mDesc.get_lengths()[1];
const int N = c_b_m_n.mDesc.get_lengths()[2];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename CompDataType, typename BDataType>
void reference_batched_softmax(
CK_TILE_HOST void reference_batched_softmax(
const HostTensor<ADataType>& a_b_m_n,
HostTensor<BDataType>& b_b_m_n,
std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)

View File

@@ -16,12 +16,12 @@ template <typename ADataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = b_n_k.mDesc.get_lengths()[1];

View File

@@ -10,25 +10,25 @@
namespace ck_tile {
template <typename T>
void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
{
int GemmM = in_mtx_host_ref.get_lengths()[0];
int GemmK = in_mtx_host_ref.get_lengths()[1];

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];

View File

@@ -10,12 +10,13 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m_n)
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
AccDataType v_max = ck_tile::NumericLimits<ADataType>::Lowest();
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)