mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Did some work on SD rescaler. Need to run a long test on it eventually.
This commit is contained in:
@@ -114,7 +114,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
tokenizer=self.sd.tokenizer[0],
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=self.sd.noise_scheduler,
|
||||
)
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=self.sd.vae,
|
||||
@@ -125,7 +125,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
).to(self.device_torch)
|
||||
# disable progress bar
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -387,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
text_embeddings: PromptEmbeds,
|
||||
timestep: int,
|
||||
guidance_scale=7.5,
|
||||
guidance_rescale=0, # 0.7
|
||||
guidance_rescale=0, # 0.7
|
||||
add_time_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -585,17 +585,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unet.eval()
|
||||
|
||||
if self.network_config is not None:
|
||||
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
|
||||
self.network = LoRASpecialNetwork(
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
lora_dim=self.network_config.linear,
|
||||
multiplier=1.0,
|
||||
alpha=self.network_config.alpha,
|
||||
alpha=self.network_config.linear_alpha,
|
||||
train_unet=self.train_config.train_unet,
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
conv_lora_dim=conv,
|
||||
conv_alpha=self.network_config.alpha if conv is not None else None,
|
||||
conv_lora_dim=self.network_config.conv,
|
||||
conv_alpha=self.network_config.conv_alpha,
|
||||
)
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user