From b0bd6e3895652b387d9a57b6b2f21910ff972f70 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 24 Aug 2023 01:13:07 +0800 Subject: [PATCH] Add workspace setting up for batchnorm bwd/fwd client examples (#860) [ROCm/composable_kernel commit: 350d64f351701af68e58c64e0e97efe07f7bf126] --- client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp | 6 ++++++ client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index c0140f71c1..1ed36e0f50 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -191,6 +191,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index 3653733436..f9af011c84 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -187,6 +187,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); }