mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
BatchNorm forward instance/external api/profiler/tests/client example (#511)
* Update to device_batchnorm_forward base class to include all template parameters for problem description * Add batchnorm forward instances and external api * Add batchnorm forward profiler module which uses the external api * Add some comments in batchnorm_forward example to explain the dimensions in lengths[] * Replace the reference_batchnorm_forward_nhwc_c by generic reference_batchnorm_forward * Improvement to the batchnorm infer base API * Add batchnorm forward client example which shows using the batchnorm forward external API * Add test for batchnorm forward * Tuning the batchnorm profiler initialized values and error threshold * Add support for bhalf_t in instances/external api/tests * Add support for int8_t in instances/external api/tests * Add support for double in instances/external api/tests * Let ScaleDataType and BiasDataType be same as XDataType and YDataType when creating instances * Checking before running best instance in batchnorm_fwd_nhwc client example * Add checking for YElementwiseOp in batchnorm_forward external API * Add more types in batchnorm forward profiler * Add more test lengths Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -24,6 +24,7 @@ int profile_softmax(int, char*[]);
|
||||
int profile_layernorm(int, char*[]);
|
||||
int profile_groupnorm(int, char*[]);
|
||||
int profile_reduce(int, char*[]);
|
||||
int profile_batchnorm_forward(int, char*[]);
|
||||
|
||||
static void print_helper_message()
|
||||
{
|
||||
@@ -46,7 +47,8 @@ static void print_helper_message()
|
||||
" grouped_conv_fwd: Grouped Convolution Forward\n"
|
||||
" grouped_conv_bwd_weight: Grouped Convolution Backward Weight\n"
|
||||
" softmax: Softmax\n"
|
||||
" reduce: Reduce\n");
|
||||
" reduce: Reduce\n"
|
||||
" bnorm_fwd: Batchnorm forward\n");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -142,6 +144,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_groupnorm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "bnorm_fwd") == 0)
|
||||
{
|
||||
return profile_batchnorm_forward(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_helper_message();
|
||||
|
||||
Reference in New Issue
Block a user