Adapter work. Bug fixes. Auto adjust LR when resuming optimizer.

This commit is contained in:
Jaret Burkett
2024-03-17 10:21:47 -06:00
parent 72de68d8aa
commit 016687bda1
8 changed files with 84 additions and 15 deletions

View File

@@ -7,7 +7,7 @@ import numpy as np
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLAdapterPipeline
from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
from tqdm import tqdm
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
@@ -100,25 +100,43 @@ class ReferenceGenerator(BaseExtensionProcess):
if self.generate_config.t2i_adapter_path is not None:
self.adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
self.generate_config.t2i_adapter_path,
torch_dtype=self.torch_dtype,
varient="fp16"
).to(device)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to(device)
pipe = StableDiffusionXLAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=get_sampler(self.generate_config.sampler),
adapter=self.adapter,
).to(device)
if self.model_config.is_xl:
pipe = StableDiffusionXLAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder[0],
text_encoder_2=self.sd.text_encoder[1],
tokenizer=self.sd.tokenizer[0],
tokenizer_2=self.sd.tokenizer[1],
scheduler=get_sampler(self.generate_config.sampler),
adapter=self.adapter,
).to(device, dtype=self.torch_dtype)
else:
pipe = StableDiffusionAdapterPipeline(
vae=self.sd.vae,
unet=self.sd.unet,
text_encoder=self.sd.text_encoder,
tokenizer=self.sd.tokenizer,
scheduler=get_sampler(self.generate_config.sampler),
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
adapter=self.adapter,
).to(device, dtype=self.torch_dtype)
pipe.set_progress_bar_config(disable=True)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
num_batches = len(self.data_loader)
@@ -176,6 +194,7 @@ class ReferenceGenerator(BaseExtensionProcess):
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
guidance_scale=self.generate_config.guidance_scale,
).images[0]
os.makedirs(os.path.dirname(output_path), exist_ok=True)
gen_images.save(output_path)
# save caption