sync with upstream

This commit is contained in:
carlushuang
2024-03-26 16:05:54 +00:00
parent 04ee01191a
commit 1c92c5d83d
268 changed files with 16113 additions and 2241 deletions

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Operation_Xdl_CShuffle
{
static std::vector<std::vector<Operation_Xdl_CShuffle>> CreateOperations();
static std::vector<Operation_Xdl_CShuffle> CreateOperations(const Problem& prob);
TensorDesc A{};
TensorDesc B{};
DataType acc = DataType::Float;
DataType cs_type = DataType::Half;
std::vector<TensorDesc> Ds = {};
TensorDesc E{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string cde_elem_op = Bilinear;
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
operation::TileDesc tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
Solution ToSolution() const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string CDEElementOp = PassThrough;
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,18 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <string_view>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
std::unordered_map<std::string_view, std::string_view> GetHeaders();
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck {
namespace host {
namespace operation {
struct TileDesc
{
int block_size = 0;
int m_per_block = 0;
int n_per_block = 0;
int k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int m_Xdl_per_wave = 0;
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc
{
std::string thread_cluster_length = "";
std::string thread_cluster_arrange_order = "";
std::string src_access_order = "";
int src_vec_dim = 0;
int src_scalar_per_vector = 0;
int dst_scalar_per_vector_k1 = 0;
int lds_add_extra_dim = 0;
};
struct CShuffleDesc
{
int m_Xdl_per_wave_per_shuffle = 0;
int n_Xdl_per_wave_per_shuffle = 0;
};
struct CBlockTransferDesc
{
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
int scalar_per_vector_n_wave_n_per_Xdl = 0;
};
} // namespace operation
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <numeric>
#include <string>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
template <class F>
std::string trim(const std::string& s, F f)
{
auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return {start, last};
}
inline std::string trim(const std::string& s)
{
return trim(s, [](unsigned char c) { return std::isspace(c); });
}
template <class Strings>
inline std::string JoinStrings(Strings strings, const std::string& delim)
{
auto it = strings.begin();
if(it == strings.end())
return "";
auto nit = std::next(it);
return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
return std::move(x) + delim + std::move(y);
});
}
template <class F>
inline std::string
InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
std::string result = "";
result.reserve(input.size());
auto it = input.begin();
while(it != input.end())
{
auto next_start = std::search(it, input.end(), start.begin(), start.end());
auto next_end = std::search(next_start, input.end(), end.begin(), end.end());
result.append(it, next_start);
if(next_start == input.end())
break;
if(next_end == input.end())
{
throw std::runtime_error("Unbalanced brackets");
}
auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end());
it = next_end + end.size();
}
return result;
}
inline std::string InterpolateString(const std::string& input,
const std::unordered_map<std::string, std::string>& vars,
std::string start = "${",
std::string end = "}")
{
return InterpolateString(
input,
[&](auto start_it, auto last_it) {
auto key = trim({start_it, last_it});
auto it = vars.find(key);
if(it == vars.end())
throw std::runtime_error("Unknown key: " + key);
return it->second;
},
std::move(start),
std::move(end));
}
template <class Range, class F>
inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin()))>
{
std::vector<decltype(f(*r.begin()))> result;
std::transform(r.begin(), r.end(), std::back_inserter(result), f);
return result;
}
template <class Range1, class Range2, class F>
inline auto Transform(const Range1& r1, const Range2& r2, F f)
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
{
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f);
return result;
}
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <utility>
#include <unordered_map>
#include <vector>
namespace ck {
namespace host {
struct Solution
{
Solution() = default;
Solution(std::string str, std::unordered_map<std::string, std::string> values);
std::string ToTemplateString() const;
std::string GetTemplateParameter(const std::string& name) const;
template <class T>
T GetTemplateParameter(const std::string& name) const
{
T result;
std::stringstream ss(GetTemplateParameter(name));
ss >> result;
return result;
}
private:
std::string template_str;
std::unordered_map<std::string, std::string> template_values;
};
enum class DataType
{
Half,
Float,
Int8,
Int32
};
std::string ToString(DataType dt);
enum class Layout
{
Row,
Column
};
std::string ToString(Layout dl);
enum class GemmType
{
Default
};
std::string ToString(GemmType gt);
struct TensorDesc
{
DataType element;
Layout layout;
};
std::string SequenceStr(const std::vector<int>& v);
std::string MakeTuple(const std::vector<std::string>& v);
template <int... xs>
const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdint>
#include <unordered_set>
namespace ck {
namespace host {
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
const std::unordered_set<std::string>& get_xdlop_archs();
} // namespace host
} // namespace ck