mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
BatchNorm forward instance/external api/profiler/tests/client example (#511)
* Update to device_batchnorm_forward base class to include all template parameters for problem description
* Add batchnorm forward instances and external api
* Add batchnorm forward profiler module which uses the external api
* Add some comments in batchnorm_forward example to explain the dimensions in lengths[]
* Replace the reference_batchnorm_forward_nhwc_c by generic reference_batchnorm_forward
* Improvement to the batchnorm infer base API
* Add batchnorm forward client example which shows using the batchnorm forward external API
* Add test for batchnorm forward
* Tuning the batchnorm profiler initialized values and error threshold
* Add support for bhalf_t in instances/external api/tests
* Add support for int8_t in instances/external api/tests
* Add support for double in instances/external api/tests
* Let ScaleDataType and BiasDataType be same as XDataType and YDataType when creating instances
* Checking before running best instance in batchnorm_fwd_nhwc client example
* Add checking for YElementwiseOp in batchnorm_forward external API
* Add more types in batchnorm forward profiler
* Add more test lengths
Co-authored-by: rocking5566 <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 4e6a5575be]
This commit is contained in:
@@ -24,6 +24,7 @@ int profile_softmax(int, char*[]);
|
||||
int profile_layernorm(int, char*[]);
|
||||
int profile_groupnorm(int, char*[]);
|
||||
int profile_reduce(int, char*[]);
|
||||
int profile_batchnorm_forward(int, char*[]);
|
||||
|
||||
static void print_helper_message()
|
||||
{
|
||||
@@ -46,7 +47,8 @@ static void print_helper_message()
|
||||
" grouped_conv_fwd: Grouped Convolution Forward\n"
|
||||
" grouped_conv_bwd_weight: Grouped Convolution Backward Weight\n"
|
||||
" softmax: Softmax\n"
|
||||
" reduce: Reduce\n");
|
||||
" reduce: Reduce\n"
|
||||
" bnorm_fwd: Batchnorm forward\n");
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -142,6 +144,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_groupnorm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "bnorm_fwd") == 0)
|
||||
{
|
||||
return profile_batchnorm_forward(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_helper_message();
|
||||
|
||||
Reference in New Issue
Block a user