mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Apply Ck-tile argument parser for vectors [I/O] (#1758)
* Parser for a vector was added. Additionaly we valid correctnes of numbers
* Remove unnecessary comments
* Review part 1
* Review part 2
* Add const to variadic lambda
* Rename C->K
[ROCm/composable_kernel commit: e758d006a5]
This commit is contained in:
@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("group_count", "16", "group count");
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "16", "group count.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return -1;
|
||||
};
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms;
|
||||
std::vector<ck_tile::index_t> Ns;
|
||||
std::vector<ck_tile::index_t> Ks;
|
||||
std::vector<ck_tile::index_t> stride_As;
|
||||
std::vector<ck_tile::index_t> stride_Bs;
|
||||
std::vector<ck_tile::index_t> stride_Cs;
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(128 + 64 * i);
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(128 + 64 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
|
||||
@@ -15,11 +15,14 @@
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
* a host side utility, arg parser for
|
||||
* -[key0]=[value0] -[key1]=[value1] ...
|
||||
* a host side utility, arg parser for, either
|
||||
* -[key0] = [value0, value1, value2]
|
||||
* or
|
||||
* -[key0]=[value0] -[key1]=[value1] ...
|
||||
*/
|
||||
class ArgParser
|
||||
{
|
||||
|
||||
public:
|
||||
class Arg
|
||||
{
|
||||
@@ -187,6 +190,45 @@ class ArgParser
|
||||
return value;
|
||||
}
|
||||
|
||||
std::vector<std::string> get_string_vec(const std::string& name,
|
||||
const std::string& delimiter = ",") const
|
||||
{
|
||||
if(get_str(name).empty())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
std::string s = get_str(name);
|
||||
std::vector<std::string> tokens;
|
||||
size_t pos = 0;
|
||||
std::string token;
|
||||
while((pos = s.find(delimiter)) != std::string::npos)
|
||||
{
|
||||
token = s.substr(0, pos);
|
||||
tokens.push_back(token);
|
||||
s.erase(0, pos + delimiter.length());
|
||||
}
|
||||
tokens.push_back(s);
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
|
||||
{
|
||||
if(get_str(name).empty())
|
||||
{
|
||||
return {};
|
||||
}
|
||||
const std::vector<std::string> args = get_string_vec(name, delimiter);
|
||||
std::vector<int> tokens;
|
||||
tokens.reserve(static_cast<int>(args.size()));
|
||||
for(const std::string& token : args)
|
||||
{
|
||||
int value = atoi(token.c_str());
|
||||
tokens.push_back(value);
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, Arg> input_map;
|
||||
std::vector<std::string> keys;
|
||||
|
||||
Reference in New Issue
Block a user