Did some work on SD rescaler. Need to run a long test on it eventually.

This commit is contained in:
Jaret Burkett
2023-08-02 07:59:27 -06:00
parent 2bf3e529ce
commit 1a25b275c8
4 changed files with 96 additions and 60 deletions

View File

@@ -114,7 +114,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=self.sd.noise_scheduler,
)
).to(self.device_torch)
else:
pipeline = StableDiffusionPipeline(
vae=self.sd.vae,
@@ -125,7 +125,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
).to(self.device_torch)
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
@@ -387,7 +387,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
text_embeddings: PromptEmbeds,
timestep: int,
guidance_scale=7.5,
guidance_rescale=0, # 0.7
guidance_rescale=0, # 0.7
add_time_ids=None,
**kwargs,
):
@@ -585,17 +585,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet.eval()
if self.network_config is not None:
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
self.network = LoRASpecialNetwork(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.alpha,
alpha=self.network_config.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=conv,
conv_alpha=self.network_config.alpha if conv is not None else None,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
)
self.network.force_to(self.device_torch, dtype=dtype)