WIP on clip vision encoder

This commit is contained in:
Jaret Burkett
2024-03-13 07:24:08 -06:00
parent d87b49882c
commit 72de68d8aa
4 changed files with 164 additions and 73 deletions

View File

@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -242,10 +242,21 @@ class StableDiffusion:
device_map="auto",
torch_dtype=self.torch_dtype,
)
# load the transformer
subfolder = "transformer"
# check if it is just the unet
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
subfolder = None
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
model_path,
"PixArt-alpha/PixArt-XL-2-1024-MS",
transformer=transformer,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
@@ -1081,10 +1092,14 @@ class StableDiffusion:
else:
noise_pred = noise_pred
else:
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
@@ -1485,10 +1500,16 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
if self.is_pixart:
self.unet.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
else:
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
with open(meta_path, 'w') as f: