mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Softmax client example (#396)
* Update Softmax device operation interface.
* Update ckProfiler.
* Update Softmax UT.
* Update example.
* Client example.
* Clang format

Co-authored-by: Adam Osewski <aosewski@amd.com>
This commit is contained in:
@@ -50,7 +50,7 @@ struct ArgParser
|
||||
|
||||
void print_help()
|
||||
{
|
||||
std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n"
|
||||
std::cout << "arg1: tensor operation (batchnorm/softmax)\n"
|
||||
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
|
||||
<< "arg3: verification (0: no; 1: yes)\n"
|
||||
<< "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
|
||||
@@ -91,31 +91,73 @@ int profile_normalization(int argc, char* argv[])
|
||||
arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
|
||||
const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];
|
||||
|
||||
if(data_type == NormDataType::F16_F16)
|
||||
if(length.size() == 3)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
if(data_type == NormDataType::F16_F16)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 3>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
}
|
||||
else if(data_type == NormDataType::F32_F32)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<float, float, float, 3>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("not implemented yet");
|
||||
}
|
||||
}
|
||||
else if(data_type == NormDataType::F32_F32)
|
||||
else if(length.size() == 4)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<float, float, float>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
if(data_type == NormDataType::F16_F16)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t, 4>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
}
|
||||
else if(data_type == NormDataType::F32_F32)
|
||||
{
|
||||
ck::profiler::profile_normalization_impl<float, float, float, 4>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
length,
|
||||
stride,
|
||||
reduce,
|
||||
float(alpha),
|
||||
float(beta),
|
||||
norm_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("not implemented yet");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user