diff --git a/backend/nn/unet.py b/backend/nn/unet.py index a8c721d6..33e0bf36 100644 --- a/backend/nn/unet.py +++ b/backend/nn/unet.py @@ -202,7 +202,6 @@ class BasicTransformerBlock(nn.Module): self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head - self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)