Refactor elementwise kernels (#1222)

* Refactor elementwise kernels

* Instances fixes

* Fix cmake

* Fix max pool bwd test

* Update two stage gemm split k

* Restore elementwise scale for hiptensor backward compatibility

* Fix Acc data type check in conv fwd multiple abd

* Disable conv fp64 fwd example

* Update grouped conv weight multi d

[ROCm/composable_kernel commit: ad1597c499]
This commit is contained in:
Bartłomiej Kocot
2024-04-19 13:31:17 +02:00
committed by GitHub
parent e7d121a6f0
commit 6578635cb3
38 changed files with 513 additions and 2502 deletions

View File

@@ -18,39 +18,6 @@ enum struct DataType
#define OP_NAME "transpose"
#define OP_DESC "Transpose"
/// Minimal parser for "--<name> v1 v2 ..." style long options holding integer lists.
///
/// Each key of `long_opts` is an option name (without the leading "--"); the
/// mapped vector collects the integers that follow that option on the command
/// line.
struct TransposeArgParser
{
    // Recognised options and the integer values gathered for each of them.
    std::unordered_map<std::string, std::vector<int>> long_opts = {{"lengths", {}}};

    /// Tries to read option `key` at position `i` of `argv`.
    ///
    /// Returns false (touching nothing) when argv[i] is not exactly "--<key>".
    /// Otherwise every following argument, up to the end of the command line
    /// or the first argument starting with '-', is converted with std::stoi
    /// and appended to long_opts[key]; the function then returns true.
    bool parse_opt(const int argc, char* argv[], const std::string& key, int i)
    {
        const std::string flag = "--" + key;
        if(flag != argv[i])
        {
            return false;
        }

        // Consume values until the next dash-prefixed token or the end of argv.
        for(int cursor = i + 1; cursor < argc && argv[cursor][0] != '-'; ++cursor)
        {
            long_opts[key].push_back(std::stoi(argv[cursor]));
        }
        return true;
    }

    /// Scans argv[1..argc) once per registered option and parses the first
    /// occurrence of each option that is found.
    void operator()(int argc, char* argv[])
    {
        for(auto& entry : long_opts)
        {
            int idx = 1;
            // Advance until the option is matched (and parsed) or argv is exhausted.
            while(idx < argc && !parse_opt(argc, argv, entry.first, idx))
            {
                ++idx;
            }
        }
    }
};
static void print_helper_msg()
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
@@ -59,25 +26,27 @@ static void print_helper_msg()
printf("arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg5: print tensor value (0: no; 1: yes)\n");
printf("arg6: time kernel (0=no, 1=yes)\n");
printf("arg7: --lengths: N, C, D, H, W\n");
printf("arg7 to arg11: N, C, D, H, W\n");
}
int profile_transpose(int argc, char* argv[])
{
if(argc != 7)
if(argc != 12)
{
print_helper_msg();
exit(1);
}
TransposeArgParser arg_parser;
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
arg_parser(argc, argv);
const std::vector<ck::index_t> lengths = arg_parser.long_opts["lengths"];
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
const std::vector<ck::index_t> lengths = {std::stoi(argv[7]),
std::stoi(argv[8]),
std::stoi(argv[9]),
std::stoi(argv[10]),
std::stoi(argv[11])};
using F32 = float;
using F16 = ck::half_t;