mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
sync with upstream
This commit is contained in:
42
codegen/include/ck/host/device_gemm_multiple_d.hpp
Normal file
42
codegen/include/ck/host/device_gemm_multiple_d.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include "ck/host/types.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransE = false;
|
||||
std::vector<bool> DsTrans = {};
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType EDataType = DataType::Half;
|
||||
std::vector<DataType> DsDataType = {};
|
||||
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string CDEElementOp = "ck::Tuple<>";
|
||||
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
std::vector<Solution> GetSolutions(const std::string& arch) const;
|
||||
};
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
42
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
Normal file
42
codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
Normal file
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/operation/gemm.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
struct Operation_Xdl_CShuffle
|
||||
{
|
||||
static std::vector<std::vector<Operation_Xdl_CShuffle>> CreateOperations();
|
||||
static std::vector<Operation_Xdl_CShuffle> CreateOperations(const Problem& prob);
|
||||
TensorDesc A{};
|
||||
TensorDesc B{};
|
||||
DataType acc = DataType::Float;
|
||||
DataType cs_type = DataType::Half;
|
||||
std::vector<TensorDesc> Ds = {};
|
||||
TensorDesc E{};
|
||||
std::string a_elem_op = PassThrough;
|
||||
std::string b_elem_op = PassThrough;
|
||||
std::string cde_elem_op = Bilinear;
|
||||
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
operation::TileDesc tile_desc{};
|
||||
operation::BlockTransferDesc a_block_transfer{};
|
||||
operation::BlockTransferDesc b_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
|
||||
Solution ToSolution() const;
|
||||
};
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
39
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
Normal file
39
codegen/include/ck/host/device_gemm_multiple_d/problem.hpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransE = false;
|
||||
std::vector<bool> DsTrans = {};
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType EDataType = DataType::Half;
|
||||
std::vector<DataType> DsDataType = {};
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string CDEElementOp = PassThrough;
|
||||
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
std::vector<Solution> GetSolutions(const std::string& arch) const;
|
||||
};
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
18
codegen/include/ck/host/headers.hpp
Normal file
18
codegen/include/ck/host/headers.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
std::unordered_map<std::string_view, std::string_view> GetHeaders();
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
49
codegen/include/ck/host/operation/gemm.hpp
Normal file
49
codegen/include/ck/host/operation/gemm.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace operation {
|
||||
|
||||
struct TileDesc
|
||||
{
|
||||
int block_size = 0;
|
||||
int m_per_block = 0;
|
||||
int n_per_block = 0;
|
||||
int k_per_block = 0;
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int m_per_XDL = 0;
|
||||
int n_per_XDL = 0;
|
||||
int m_Xdl_per_wave = 0;
|
||||
int n_Xdl_per_wave = 0;
|
||||
int num_gemmk_prefetch_stage = 0;
|
||||
};
|
||||
struct BlockTransferDesc
|
||||
{
|
||||
std::string thread_cluster_length = "";
|
||||
std::string thread_cluster_arrange_order = "";
|
||||
std::string src_access_order = "";
|
||||
int src_vec_dim = 0;
|
||||
int src_scalar_per_vector = 0;
|
||||
int dst_scalar_per_vector_k1 = 0;
|
||||
int lds_add_extra_dim = 0;
|
||||
};
|
||||
struct CShuffleDesc
|
||||
{
|
||||
int m_Xdl_per_wave_per_shuffle = 0;
|
||||
int n_Xdl_per_wave_per_shuffle = 0;
|
||||
};
|
||||
struct CBlockTransferDesc
|
||||
{
|
||||
std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = "";
|
||||
int scalar_per_vector_n_wave_n_per_Xdl = 0;
|
||||
};
|
||||
|
||||
} // namespace operation
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
104
codegen/include/ck/host/stringutils.hpp
Normal file
104
codegen/include/ck/host/stringutils.hpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
template <class F>
|
||||
std::string trim(const std::string& s, F f)
|
||||
{
|
||||
auto start = std::find_if_not(s.begin(), s.end(), f);
|
||||
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
|
||||
return {start, last};
|
||||
}
|
||||
|
||||
inline std::string trim(const std::string& s)
|
||||
{
|
||||
return trim(s, [](unsigned char c) { return std::isspace(c); });
|
||||
}
|
||||
|
||||
template <class Strings>
|
||||
inline std::string JoinStrings(Strings strings, const std::string& delim)
|
||||
{
|
||||
auto it = strings.begin();
|
||||
if(it == strings.end())
|
||||
return "";
|
||||
|
||||
auto nit = std::next(it);
|
||||
return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
|
||||
return std::move(x) + delim + std::move(y);
|
||||
});
|
||||
}
|
||||
|
||||
template <class F>
|
||||
inline std::string
|
||||
InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}")
|
||||
{
|
||||
std::string result = "";
|
||||
result.reserve(input.size());
|
||||
auto it = input.begin();
|
||||
while(it != input.end())
|
||||
{
|
||||
auto next_start = std::search(it, input.end(), start.begin(), start.end());
|
||||
auto next_end = std::search(next_start, input.end(), end.begin(), end.end());
|
||||
result.append(it, next_start);
|
||||
if(next_start == input.end())
|
||||
break;
|
||||
if(next_end == input.end())
|
||||
{
|
||||
throw std::runtime_error("Unbalanced brackets");
|
||||
}
|
||||
auto r = f(next_start + start.size(), next_end);
|
||||
result.append(r.begin(), r.end());
|
||||
it = next_end + end.size();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
inline std::string InterpolateString(const std::string& input,
|
||||
const std::unordered_map<std::string, std::string>& vars,
|
||||
std::string start = "${",
|
||||
std::string end = "}")
|
||||
{
|
||||
return InterpolateString(
|
||||
input,
|
||||
[&](auto start_it, auto last_it) {
|
||||
auto key = trim({start_it, last_it});
|
||||
auto it = vars.find(key);
|
||||
if(it == vars.end())
|
||||
throw std::runtime_error("Unknown key: " + key);
|
||||
return it->second;
|
||||
},
|
||||
std::move(start),
|
||||
std::move(end));
|
||||
}
|
||||
|
||||
template <class Range, class F>
|
||||
inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin()))>
|
||||
{
|
||||
std::vector<decltype(f(*r.begin()))> result;
|
||||
std::transform(r.begin(), r.end(), std::back_inserter(result), f);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class Range1, class Range2, class F>
|
||||
inline auto Transform(const Range1& r1, const Range2& r2, F f)
|
||||
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
{
|
||||
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
|
||||
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
|
||||
std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
78
codegen/include/ck/host/types.hpp
Normal file
78
codegen/include/ck/host/types.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
struct Solution
|
||||
{
|
||||
|
||||
Solution() = default;
|
||||
Solution(std::string str, std::unordered_map<std::string, std::string> values);
|
||||
std::string ToTemplateString() const;
|
||||
std::string GetTemplateParameter(const std::string& name) const;
|
||||
template <class T>
|
||||
T GetTemplateParameter(const std::string& name) const
|
||||
{
|
||||
T result;
|
||||
std::stringstream ss(GetTemplateParameter(name));
|
||||
ss >> result;
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string template_str;
|
||||
std::unordered_map<std::string, std::string> template_values;
|
||||
};
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
Half,
|
||||
Float,
|
||||
Int8,
|
||||
Int32
|
||||
};
|
||||
|
||||
std::string ToString(DataType dt);
|
||||
|
||||
enum class Layout
|
||||
{
|
||||
Row,
|
||||
Column
|
||||
};
|
||||
|
||||
std::string ToString(Layout dl);
|
||||
|
||||
enum class GemmType
|
||||
{
|
||||
Default
|
||||
};
|
||||
|
||||
std::string ToString(GemmType gt);
|
||||
|
||||
struct TensorDesc
|
||||
{
|
||||
DataType element;
|
||||
Layout layout;
|
||||
};
|
||||
|
||||
std::string SequenceStr(const std::vector<int>& v);
|
||||
|
||||
std::string MakeTuple(const std::vector<std::string>& v);
|
||||
|
||||
template <int... xs>
|
||||
const std::string S = SequenceStr({xs...});
|
||||
|
||||
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
|
||||
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
17
codegen/include/ck/host/utils.hpp
Normal file
17
codegen/include/ck/host/utils.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs();
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user