Don't reserve workspace memory for the B tensor when A needs explicit transposition but B does not.

This commit is contained in:
kiefer
2025-09-16 15:34:35 +00:00
parent a28d10253c
commit a8a5504f31

View File

@@ -813,11 +813,11 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
}
}
// TODO: This might use unnecessary memory when we need to transpose A but not B. Need to
// check how this is used.
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(NeedTransposeKernel)
if constexpr(NeedTransposeKernel &&
(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>()))
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
@@ -983,12 +983,10 @@ struct DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_as_grid[0] = type_convert<const void*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) / // TODO: This offset might be
// unnecessary if we are not
// doing a B transpose.
sizeof(EDataType);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
}