diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 26fb7614..540b7a69 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -245,12 +245,13 @@ class CustomAdapter(torch.nn.Module):
     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
         if self.adapter_type == 'photo_maker':
-            if self.train_image_encoder:
+            if self.config.train_image_encoder:
                 state_dict["id_encoder"] = self.vision_encoder.state_dict()
             state_dict["fuse_module"] = self.fuse_module.state_dict()
             # todo save LoRA
+            return state_dict
         else:
             raise NotImplementedError
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index cc18e0fa..b466296b 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -788,7 +788,7 @@ def apply_snr_weight(
     offset = 0
     if noise_scheduler.timesteps[0] == 1000:
         offset = 1
-    snr = torch.stack([all_snr[t - offset] for t in timesteps])
+    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
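For context: the first hunk reads train_image_encoder from the adapter's config object and adds an explicit return of the assembled state_dict; the second hunk casts each timestep to an integer before using it as an index, since PyTorch refuses to index a tensor with a float-valued tensor. The sketch below reproduces that indexing failure and the fix with placeholder data; all_snr, timesteps, and offset mirror the names in apply_snr_weight, but their shapes and values here are assumptions for illustration, not the toolkit's actual data.

```python
import torch

# Hypothetical stand-ins for the values apply_snr_weight receives; the real
# all_snr table and timesteps come from the noise scheduler in train_tools.py.
all_snr = torch.linspace(20.0, 0.05, 1000)       # assumed per-timestep SNR lookup table
timesteps = torch.tensor([981.0, 640.0, 123.0])  # float timesteps, as some samplers yield
offset = 0

# Pre-fix behavior: indexing with a float tensor raises
# "IndexError: tensors used as indices must be long, int, byte or bool tensors".
# snr = torch.stack([all_snr[t - offset] for t in timesteps])

# Post-fix behavior: casting each shifted timestep to an integer index works.
snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
print(snr)  # tensor of the three looked-up SNR values
```

Casting with .long() would work equally well; the point of the change is only that the index dtype must be integral.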