Mirror of https://github.com/ostris/ai-toolkit.git
Synced 2026-01-26 16:39:47 +00:00
More work on custom adapter
@@ -245,12 +245,13 @@ class CustomAdapter(torch.nn.Module):

     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
         if self.adapter_type == 'photo_maker':
-            if self.train_image_encoder:
+            if self.config.train_image_encoder:
                 state_dict["id_encoder"] = self.vision_encoder.state_dict()

             state_dict["fuse_module"] = self.fuse_module.state_dict()

             # todo save LoRA
             return state_dict
         else:
             raise NotImplementedError

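This hunk reads the train_image_encoder flag from the adapter's config object rather than from the adapter instance, so the id_encoder weights are only serialized when the config enables image-encoder training. The nested OrderedDict returned here implies a matching loader; the sketch below is not from the repository and the strict-flag handling is an assumption, but the key names mirror the ones saved above.

    # Hypothetical counterpart to state_dict() above (not in the ai-toolkit source):
    # restores the photo_maker sub-modules from the nested dict produced by state_dict().
    def load_state_dict(self, state_dict, strict: bool = True):
        if self.adapter_type == 'photo_maker':
            if 'id_encoder' in state_dict and self.config.train_image_encoder:
                self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
            if 'fuse_module' in state_dict:
                self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
            # todo load LoRA, mirroring the save-side todo
        else:
            raise NotImplementedError
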
@@ -788,7 +788,7 @@ def apply_snr_weight(
     offset = 0
     if noise_scheduler.timesteps[0] == 1000:
         offset = 1
-    snr = torch.stack([all_snr[t - offset] for t in timesteps])
+    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr

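The one-line change here casts the timestep index to an integer before it is used to look up the per-timestep SNR: timesteps can arrive as a float tensor, and PyTorch rejects float tensors as indices. The self-contained sketch below reproduces the lookup and the gamma-over-SNR weighting; all_snr, gamma, and the example timesteps are placeholder values, not taken from ai-toolkit.

    # Minimal sketch of the SNR lookup fixed above. Placeholder values only.
    import torch

    all_snr = torch.linspace(20.0, 0.1, 1000)       # per-timestep SNR table (assumed shape)
    timesteps = torch.tensor([981.0, 501.0, 21.0])  # float timesteps, as some schedulers produce
    gamma = 5.0
    offset = 1  # applied when the scheduler's timesteps start at 1000

    # all_snr[t - offset] raises an IndexError when t is a float tensor, because
    # float tensors cannot be used as indices; hence the .int() cast in the diff.
    snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    print(gamma_over_snr)  # the "fixed" branch uses this value directly as the loss weight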