More work on custom adapter

This commit is contained in:
Jaret Burkett
2024-01-16 17:41:26 -07:00
parent eebd3c8212
commit 655533d4c7
2 changed files with 3 additions and 2 deletions

View File

@@ -245,12 +245,13 @@ class CustomAdapter(torch.nn.Module):
     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
         if self.adapter_type == 'photo_maker':
-            if self.train_image_encoder:
+            if self.config.train_image_encoder:
                 state_dict["id_encoder"] = self.vision_encoder.state_dict()
             state_dict["fuse_module"] = self.fuse_module.state_dict()
+            # todo save LoRA
             return state_dict
         else:
             raise NotImplementedError
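
The line-level change gates the id_encoder weights on the adapter's config object rather than on an attribute of the module itself. Below is a minimal toy sketch (standalone classes, not the repo's actual CustomAdapter or its config type) of how such a config flag decides which sub-module weights end up in the returned state dict:

# Minimal toy sketch, not the repo's CustomAdapter: a config flag controls
# which sub-module weights are included in the returned state dict.
from collections import OrderedDict
from dataclasses import dataclass

import torch


@dataclass
class ToyAdapterConfig:
    train_image_encoder: bool = False  # assumed flag name, mirrors the diff


class ToyAdapter(torch.nn.Module):
    def __init__(self, config: ToyAdapterConfig):
        super().__init__()
        self.config = config
        self.vision_encoder = torch.nn.Linear(4, 4)  # stand-in for the real encoder
        self.fuse_module = torch.nn.Linear(4, 4)     # stand-in for the fuse module

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        # Persist the image encoder only when it is being trained;
        # the fuse module is always part of the checkpoint.
        if self.config.train_image_encoder:
            state_dict["id_encoder"] = self.vision_encoder.state_dict()
        state_dict["fuse_module"] = self.fuse_module.state_dict()
        return state_dict


print(ToyAdapter(ToyAdapterConfig(train_image_encoder=True)).state_dict().keys())
# -> odict_keys(['id_encoder', 'fuse_module'])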

View File

@@ -788,7 +788,7 @@ def apply_snr_weight(
     offset = 0
     if noise_scheduler.timesteps[0] == 1000:
         offset = 1
-    snr = torch.stack([all_snr[t - offset] for t in timesteps])
+    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
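
The cast matters because timesteps commonly arrives as a float tensor, and PyTorch refuses to index a tensor with float-typed tensor indices. A minimal sketch with toy values (gamma handling simplified, not the full apply_snr_weight helper) of the failure mode and the fix:

# Minimal sketch with toy values: why (t - offset).int() is needed before
# using the timestep to index into the per-timestep SNR table.
import torch

all_snr = torch.linspace(20.0, 0.1, steps=1000)   # per-timestep SNR lookup table
timesteps = torch.tensor([999.0, 500.0, 10.0])    # often arrives as a float tensor
offset = 1

# all_snr[t - offset] would raise IndexError here: tensor indices must be
# long/int/byte/bool, and (t - offset) is still a float tensor.
snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])

gamma = 5.0
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
print(snr, gamma_over_snr)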