mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
Don't reserve workspace memory for the B tensor in the case where A needs explicit transposition but B does not.
This commit is contained in:
@@ -813,11 +813,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This might use unnecessary memory when we need to transpose A but not B. Need to
|
||||
// check how this is used.
|
||||
std::size_t GetWorkspaceBTensorSizeBytes() const
|
||||
{
|
||||
if constexpr(NeedTransposeKernel)
|
||||
if constexpr(NeedTransposeKernel &&
|
||||
(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>()))
|
||||
{
|
||||
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
|
||||
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
@@ -983,12 +983,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_as_grid[0] = type_convert<const void*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() +
|
||||
arg.GetWorkspaceBTensorSizeBytes()) / // TODO: This offset might be
|
||||
// unnecessary if we are not
|
||||
// doing a B transpose.
|
||||
sizeof(EDataType);
|
||||
p_e_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
|
||||
sizeof(EDataType);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user