mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes
This commit is contained in:
@@ -68,6 +68,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
|
||||
if self.adapter is not None:
|
||||
self.adapter.to(self.device_torch)
|
||||
|
||||
# you can expand these in a child class to make customization easier
|
||||
def calculate_loss(
|
||||
@@ -507,8 +509,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.unet.train()
|
||||
prior_pred = prior_pred.detach()
|
||||
# remove the residuals as we wont use them on prediction when matching control
|
||||
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
|
||||
del pred_kwargs['down_block_additional_residuals']
|
||||
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
|
||||
del pred_kwargs['down_intrablock_additional_residuals']
|
||||
# restore network
|
||||
# self.network.multiplier = network_weight_list
|
||||
self.network.is_active = was_network_active
|
||||
@@ -746,7 +748,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
down_block_additional_residuals
|
||||
]
|
||||
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
prior_pred = None
|
||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
|
||||
|
||||
Reference in New Issue
Block a user