mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Adapter work. Bug fixes. Auto adjust LR when resuming optimizer.
This commit is contained in:
@@ -7,7 +7,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from torch.utils.data import DataLoader
|
||||
from diffusers import StableDiffusionXLAdapterPipeline
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
|
||||
@@ -100,25 +100,43 @@ class ReferenceGenerator(BaseExtensionProcess):
|
||||
|
||||
if self.generate_config.t2i_adapter_path is not None:
|
||||
self.adapter = T2IAdapter.from_pretrained(
|
||||
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=self.torch_dtype, varient="fp16"
|
||||
self.generate_config.t2i_adapter_path,
|
||||
torch_dtype=self.torch_dtype,
|
||||
varient="fp16"
|
||||
).to(device)
|
||||
|
||||
midas_depth = MidasDetector.from_pretrained(
|
||||
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
|
||||
).to(device)
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline(
|
||||
vae=self.sd.vae,
|
||||
unet=self.sd.unet,
|
||||
text_encoder=self.sd.text_encoder[0],
|
||||
text_encoder_2=self.sd.text_encoder[1],
|
||||
tokenizer=self.sd.tokenizer[0],
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=get_sampler(self.generate_config.sampler),
|
||||
adapter=self.adapter,
|
||||
).to(device)
|
||||
if self.model_config.is_xl:
|
||||
pipe = StableDiffusionXLAdapterPipeline(
|
||||
vae=self.sd.vae,
|
||||
unet=self.sd.unet,
|
||||
text_encoder=self.sd.text_encoder[0],
|
||||
text_encoder_2=self.sd.text_encoder[1],
|
||||
tokenizer=self.sd.tokenizer[0],
|
||||
tokenizer_2=self.sd.tokenizer[1],
|
||||
scheduler=get_sampler(self.generate_config.sampler),
|
||||
adapter=self.adapter,
|
||||
).to(device, dtype=self.torch_dtype)
|
||||
else:
|
||||
pipe = StableDiffusionAdapterPipeline(
|
||||
vae=self.sd.vae,
|
||||
unet=self.sd.unet,
|
||||
text_encoder=self.sd.text_encoder,
|
||||
tokenizer=self.sd.tokenizer,
|
||||
scheduler=get_sampler(self.generate_config.sampler),
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
adapter=self.adapter,
|
||||
).to(device, dtype=self.torch_dtype)
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
# midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
|
||||
|
||||
num_batches = len(self.data_loader)
|
||||
@@ -176,6 +194,7 @@ class ReferenceGenerator(BaseExtensionProcess):
|
||||
adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
|
||||
guidance_scale=self.generate_config.guidance_scale,
|
||||
).images[0]
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
gen_images.save(output_path)
|
||||
|
||||
# save caption
|
||||
|
||||
Reference in New Issue
Block a user