diff --git a/backend/nn/unet.py b/backend/nn/unet.py index 33e0bf36..30122752 100644 --- a/backend/nn/unet.py +++ b/backend/nn/unet.py @@ -281,10 +281,7 @@ class BasicTransformerBlock(nn.Module): if self.attn2 is not None: n = self.norm2(x) - if self.switch_temporal_ca_to_sa: - context_attn2 = n - else: - context_attn2 = context + context_attn2 = context value_attn2 = None if "attn2_patch" in transformer_patches: patch = transformer_patches["attn2_patch"]