Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

This commit is contained in:
myamlak
2022-05-17 10:23:36 +00:00
162 changed files with 3223 additions and 2327 deletions

View File

@@ -88,9 +88,9 @@ using ReferenceCGemmInstance = ck::tensor_operation::host::
int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// CGEMM shape
ck::index_t M = 3840;
@@ -105,13 +105,13 @@ int main(int argc, char* argv[])
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
@@ -223,7 +223,7 @@ int main(int argc, char* argv[])
"not support this CGEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(8) * M * N * K;
std::size_t num_btype = std::size_t(2) * sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +