diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 5920232038..4bb82baabc 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -900,9 +900,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - // init work space - p_c_workspace_grid_ = nullptr; - block_2_ctile_map_ = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); @@ -939,9 +936,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ std::vector input_left_pads_; std::vector input_right_pads_; index_t k_batch_; - - // external work space - void* p_c_workspace_grid_; }; // Invoker @@ -1017,7 +1011,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ // run kernel for bf16 with splitk const auto run_bf16_splitk = [&](const auto& kernel) { hipGetErrorString(hipMemset( - arg.p_c_workspace_grid_, + arg.p_workspace_, 0, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * sizeof(AccDataType))); @@ -1030,7 +1024,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ 0, arg.p_a_grid_, arg.p_b_grid_, - static_cast(arg.p_c_workspace_grid_), + static_cast(arg.p_workspace_), arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -1072,7 +1066,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ dim3(type_convert_grid_size), dim3(256), 0, - static_cast(arg.p_c_workspace_grid_), + static_cast(arg.p_workspace_), p_c_grid_tmp_bf16_, a_grid_desc_m0_, b_grid_desc_m0_, @@ -1448,11 +1442,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ { return GetWorkSpaceSize(*dynamic_cast(p_arg)); } - - void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override - { - dynamic_cast(p_arg)->p_c_workspace_grid_ = workspace_ptr; - } }; } // namespace device