diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp index 9510cedc43..acda55ef1e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp @@ -813,11 +813,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 } } - // TODO: This might use unnecessary memory when we need to transpose A but not B. Need to - // check how this is used. std::size_t GetWorkspaceBTensorSizeBytes() const { - if constexpr(NeedTransposeKernel) + if constexpr(NeedTransposeKernel && + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW())) { const long_index_t b_acum = ck::accumulate_n( b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); @@ -983,12 +983,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3 is_NGCDHW_GKZYXC_NGKDHW()) { p_as_grid[0] = type_convert(arg.p_workspace_); - p_e_grid = type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + - arg.GetWorkspaceBTensorSizeBytes()) / // TODO: This offset might be - // unnecessary if we are not - // doing a B transpose. - sizeof(EDataType); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); } }