mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added working ilora trainer
This commit is contained in:
@@ -46,7 +46,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
self.assistant_adapter: Union['T2IAdapter', None]
|
||||
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
|
||||
self.do_prior_prediction = False
|
||||
self.do_long_prompts = False
|
||||
self.do_guided_loss = False
|
||||
@@ -76,10 +76,18 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.adapter_assist_name_or_path is not None:
|
||||
adapter_path = self.train_config.adapter_assist_name_or_path
|
||||
|
||||
# dont name this adapter since we are not training it
|
||||
self.assistant_adapter = T2IAdapter.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||
).to(self.device_torch)
|
||||
if self.train_config.adapter_assist_type == "t2i":
|
||||
# dont name this adapter since we are not training it
|
||||
self.assistant_adapter = T2IAdapter.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
||||
).to(self.device_torch)
|
||||
elif self.train_config.adapter_assist_type == "control_net":
|
||||
self.assistant_adapter = ControlNetModel.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
|
||||
).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
|
||||
else:
|
||||
raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
|
||||
|
||||
self.assistant_adapter.eval()
|
||||
self.assistant_adapter.requires_grad_(False)
|
||||
flush()
|
||||
@@ -955,10 +963,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_strength_max = 1.0
|
||||
else:
|
||||
# training with assistance, we want it low
|
||||
adapter_strength_min = 0.4
|
||||
adapter_strength_max = 0.7
|
||||
# adapter_strength_min = 0.9
|
||||
# adapter_strength_max = 1.1
|
||||
# adapter_strength_min = 0.4
|
||||
# adapter_strength_max = 0.7
|
||||
adapter_strength_min = 0.9
|
||||
adapter_strength_max = 1.1
|
||||
|
||||
adapter_conditioning_scale = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
|
||||
Reference in New Issue
Block a user