Added working ilora trainer

This commit is contained in:
Jaret Burkett
2024-06-12 09:33:45 -06:00
parent 3f3636b788
commit cb5d28cba9
6 changed files with 261 additions and 196 deletions

View File

@@ -46,7 +46,7 @@ class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
self.do_prior_prediction = False
self.do_long_prompts = False
self.do_guided_loss = False
@@ -76,10 +76,18 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.adapter_assist_name_or_path is not None:
adapter_path = self.train_config.adapter_assist_name_or_path
# don't name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
).to(self.device_torch)
if self.train_config.adapter_assist_type == "t2i":
# don't name this adapter since we are not training it
self.assistant_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
).to(self.device_torch)
elif self.train_config.adapter_assist_type == "control_net":
self.assistant_adapter = ControlNetModel.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
else:
raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
self.assistant_adapter.eval()
self.assistant_adapter.requires_grad_(False)
flush()
@@ -955,10 +963,10 @@ class SDTrainer(BaseSDTrainProcess):
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
adapter_strength_min = 0.4
adapter_strength_max = 0.7
# adapter_strength_min = 0.9
# adapter_strength_max = 1.1
# adapter_strength_min = 0.4
# adapter_strength_max = 0.7
adapter_strength_min = 0.9
adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype