Softmax client example (#396)

* Update Softmax device operation interface.

* Update ckProfiler.

* Update Softmax UT.

* Update example.

* Client example.

* Clang format

Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
Adam Osewski
2022-09-06 19:22:48 +02:00
committed by GitHub
parent 7589116121
commit 3da5c19e62
13 changed files with 738 additions and 331 deletions

View File

@@ -50,7 +50,7 @@ struct ArgParser
void print_help()
{
std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n"
std::cout << "arg1: tensor operation (batchnorm/softmax)\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg3: verification (0: no; 1: yes)\n"
<< "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
@@ -91,31 +91,73 @@ int profile_normalization(int argc, char* argv[])
arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];
if(data_type == NormDataType::F16_F16)
if(length.size() == 3)
{
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
if(data_type == NormDataType::F16_F16)
{
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 3>(
do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else if(data_type == NormDataType::F32_F32)
{
ck::profiler::profile_normalization_impl<float, float, float, 3>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else if(data_type == NormDataType::F32_F32)
else if(length.size() == 4)
{
ck::profiler::profile_normalization_impl<float, float, float>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
if(data_type == NormDataType::F16_F16)
{
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 4>(
do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else if(data_type == NormDataType::F32_F32)
{
ck::profiler::profile_normalization_impl<float, float, float, 4>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else
{