From 655533d4c7b730713f4ac6b6fff91590af16120e Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 16 Jan 2024 17:41:26 -0700
Subject: [PATCH] More work on custom adapter

---
 toolkit/custom_adapter.py | 3 ++-
 toolkit/train_tools.py    | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 26fb7614..540b7a69 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -245,12 +245,13 @@ class CustomAdapter(torch.nn.Module):
     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
         if self.adapter_type == 'photo_maker':
-            if self.train_image_encoder:
+            if self.config.train_image_encoder:
                 state_dict["id_encoder"] = self.vision_encoder.state_dict()

             state_dict["fuse_module"] = self.fuse_module.state_dict()

             # todo save LoRA
+            return state_dict
         else:
             raise NotImplementedError

diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index cc18e0fa..b466296b 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -788,7 +788,7 @@ def apply_snr_weight(
     offset = 0
     if noise_scheduler.timesteps[0] == 1000:
         offset = 1
-    snr = torch.stack([all_snr[t - offset] for t in timesteps])
+    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr