diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py index de11bea9..5ddc19d6 100644 --- a/toolkit/models/ilora2.py +++ b/toolkit/models/ilora2.py @@ -138,7 +138,7 @@ class InstantLoRAMidModule(torch.nn.Module): self.down_dim = self.down_shape[1] if self.do_down else 0 self.mid_dim = self.up_shape[1] if self.do_mid else 0 - self.out_dim = self.up_shape[0] if self.do_down else 0 + self.out_dim = self.up_shape[0] if self.do_up else 0 self.embed = None