Change initialization method of tensor for iGEMM (#49)

* change init method

[ROCm/composable_kernel commit: 0a72e4df94]
This commit is contained in:
Chao Liu
2021-07-16 22:55:01 -05:00
committed by GitHub
parent 9e04c9faef
commit 4c61ba83c6
4 changed files with 101 additions and 79 deletions

View File

@@ -179,26 +179,38 @@ int main(int argc, char* argv[])
std::size_t num_thread = std::thread::hardware_concurrency();
if(do_verification)
switch(init_method)
{
switch(init_method)
{
case 0:
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 1:
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 2:
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
default:
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
}
case 0:
// no initialization
break;
case 1:
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 2:
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 3:
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 4:
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 5:
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
break;
default:
out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
auto gen_wei = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
wei.GenerateTensorValue(gen_wei, num_thread);
}
auto f_make_for_device_nchw = [&]() {