mirror of https://github.com/ostris/ai-toolkit.git
WIP on clip vision encoder
@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
-    StableDiffusionXLImg2ImgPipeline, LCMScheduler
+    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -242,10 +242,21 @@ class StableDiffusion:
                 device_map="auto",
                 torch_dtype=self.torch_dtype,
             )
+
+            # load the transformer
+            subfolder = "transformer"
+            # check if it is just the unet
+            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
+                subfolder = None
+            # load the transformer only from the save
+            transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
+
+
             # replace the to function with a no-op since it throws an error instead of a warning
             text_encoder.to = lambda *args, **kwargs: None
             pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
-                model_path,
+                "PixArt-alpha/PixArt-XL-2-1024-MS",
+                transformer=transformer,
                 text_encoder=text_encoder,
                 dtype=dtype,
                 device=self.device_torch,
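
The hunk above loads the PixArt transformer on its own and then injects it into the pipeline, so a checkpoint that contains only transformer weights (no "transformer" subfolder) can still be resumed while the remaining components come from the base repo. A minimal standalone sketch of that fallback, assuming a hypothetical local model_path and fp16 weights:

import os
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel

model_path = "/path/to/saved_model"  # hypothetical local save

# a full pipeline save keeps the transformer under "transformer/",
# while a bare transformer save has its config at the directory root
subfolder = "transformer"
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
    subfolder = None

transformer = Transformer2DModel.from_pretrained(
    model_path, torch_dtype=torch.float16, subfolder=subfolder
)

# the other components (VAE, scheduler, text encoder) are pulled from the base repo
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    torch_dtype=torch.float16,
)
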
@@ -1081,10 +1092,14 @@ class StableDiffusion:
                 else:
                     noise_pred = noise_pred
             else:
+                if self.unet.device != self.device_torch:
+                    self.unet.to(self.device_torch)
+                if self.unet.dtype != self.torch_dtype:
+                    self.unet = self.unet.to(dtype=self.torch_dtype)
                 noise_pred = self.unet(
                     latent_model_input.to(self.device_torch, self.torch_dtype),
                     timestep,
-                    encoder_hidden_states=text_embeddings.text_embeds,
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                     **kwargs,
                 ).sample
 
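
The guard added before the UNet forward pass is a lazy device/dtype move: the model (and the prompt embeddings) are only touched when they are not already on the sampling device and precision, which avoids a mismatch crash after, for example, an offloaded text-encoder load. A small sketch of the same idea, assuming the model exposes .device and .dtype the way a diffusers ModelMixin does; ensure_placement is a hypothetical helper name:

import torch

def ensure_placement(model, device: torch.device, dtype: torch.dtype):
    # lazy move: only touch the model when it is actually misplaced
    if model.device != device:
        model.to(device)
    if model.dtype != dtype:
        model = model.to(dtype=dtype)
    return model

# usage sketch (hypothetical names):
# unet = ensure_placement(unet, torch.device("cuda:0"), torch.float16)
# prompt_embeds = prompt_embeds.to(torch.device("cuda:0"), torch.float16)
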
@@ -1485,10 +1500,16 @@ class StableDiffusion:
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
             # diffusers
-            self.pipeline.save_pretrained(
-                save_directory=output_file,
-                safe_serialization=True,
-            )
+            if self.is_pixart:
+                self.unet.save_pretrained(
+                    save_directory=output_file,
+                    safe_serialization=True,
+                )
+            else:
+                self.pipeline.save_pretrained(
+                    save_directory=output_file,
+                    safe_serialization=True,
+                )
             # save out meta config
             meta_path = os.path.join(output_file, 'aitk_meta.yaml')
             with open(meta_path, 'w') as f:
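
For PixArt, the diffusers-format save now writes out only the transformer (held on self.unet in this class) instead of the full pipeline, which pairs with the loading change earlier in the commit: the saved directory has no "transformer" subfolder, so the subfolder check falls back to the root on reload. A rough round-trip sketch, assuming a hypothetical output_dir and using the base PixArt weights as a stand-in for a trained transformer:

import os
import torch
from diffusers import Transformer2DModel

output_dir = "/path/to/pixart_save"  # hypothetical save directory

# stand-in for a trained transformer: just the base weights here
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer", torch_dtype=torch.float16
)

# save only the transformer weights, as safetensors
transformer.save_pretrained(save_directory=output_dir, safe_serialization=True)

# reload: the save has no "transformer/" subfolder, so load from the directory root
subfolder = "transformer" if os.path.isdir(os.path.join(output_dir, "transformer")) else None
reloaded = Transformer2DModel.from_pretrained(output_dir, subfolder=subfolder, torch_dtype=torch.float16)
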