mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 08:13:58 +00:00
Work to omprove pixart training
This commit is contained in:
@@ -38,9 +38,9 @@ import torch
|
||||
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
|
||||
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
|
||||
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
|
||||
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
|
||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
|
||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
@@ -185,7 +185,6 @@ class StableDiffusion:
|
||||
}
|
||||
if self.model_config.vae_path is not None:
|
||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||
|
||||
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
@@ -331,11 +330,11 @@ class StableDiffusion:
|
||||
if self.model_config.is_pixart_sigma:
|
||||
# load the transformer only from the save
|
||||
transformer = Transformer2DModel.from_pretrained(
|
||||
model_path,
|
||||
model_path if self.model_config.unet_path is None else self.model_config.unet_path,
|
||||
torch_dtype=self.torch_dtype,
|
||||
subfolder='transformer'
|
||||
)
|
||||
pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
|
||||
pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
|
||||
main_model_path,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -357,6 +356,9 @@ class StableDiffusion:
|
||||
device=self.device_torch,
|
||||
**load_args
|
||||
).to(self.device_torch)
|
||||
|
||||
if self.model_config.unet_sample_size is not None:
|
||||
pipe.transformer.config.sample_size = self.model_config.unet_sample_size
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
@@ -592,7 +594,7 @@ class StableDiffusion:
|
||||
**extra_args
|
||||
)
|
||||
elif self.is_pixart:
|
||||
pipeline = PixArtAlphaPipeline(
|
||||
pipeline = PixArtSigmaPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
@@ -1243,7 +1245,7 @@ class StableDiffusion:
|
||||
elif self.pipeline.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user