Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Added caching to image sizes so we don't do it every time.
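The commit message describes caching image sizes so they are not re-read on every dataset setup; the hunks below are config and debug toggles from the same commit rather than the cache itself. As a rough sketch of the stated idea only, assuming a hypothetical helper get_image_size and a JSON cache file (neither is the toolkit's actual API):

import json
import os
from PIL import Image

# Hypothetical sketch: remember (width, height) per image path so sizes
# are read from disk once instead of on every dataloader setup.
SIZE_CACHE_PATH = "image_sizes.json"
size_cache = {}
if os.path.exists(SIZE_CACHE_PATH):
    with open(SIZE_CACHE_PATH, "r") as f:
        size_cache = json.load(f)

def get_image_size(path):
    if path in size_cache:
        return tuple(size_cache[path])
    # Image.open only parses the header here; pixel data is not decoded.
    with Image.open(path) as img:
        size = img.size  # (width, height)
    size_cache[path] = list(size)
    return size

def save_size_cache():
    with open(SIZE_CACHE_PATH, "w") as f:
        json.dump(size_cache, f)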
@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-large"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
+model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"


 print("Loading te adapter")
@@ -28,13 +28,13 @@ is_pixart = "pixart" in model_path.lower()

 pipeline_class = StableDiffusionPipeline

-transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)

 if is_pixart:
     pipeline_class = PixArtSigmaPipeline

 if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
 else:
     sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)

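The hunk above drops the transformer= override, so the pipeline now loads whatever transformer ships with model_path. For context, the loading pattern the script follows is a common diffusers one: choose the pipeline class from the model path, then use from_pretrained for a diffusers-format repo or folder and from_single_file for a lone checkpoint. A self-contained sketch under those assumptions (the is_pixart/is_diffusers checks here are illustrative, not the script's exact logic):

import torch
from diffusers import PixArtSigmaPipeline, StableDiffusionPipeline

model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"    # example repo id
is_pixart = "pixart" in model_path.lower()               # illustrative check
is_diffusers = not model_path.endswith(".safetensors")   # illustrative check

pipeline_class = PixArtSigmaPipeline if is_pixart else StableDiffusionPipeline

if is_diffusers:
    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
else:
    sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)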
@@ -50,7 +50,7 @@ if is_pixart:
     unet = sd.transformer
     unet_sd = sd.transformer.state_dict()
 else:
-    unet = sd.transformer
+    unet = sd.unet
     unet_sd = sd.unet.state_dict()

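The one-line fix above keeps the else branch consistent with the state_dict call under it: PixArt pipelines expose their denoiser as sd.transformer, while StableDiffusionPipeline exposes sd.unet, so the old unet = sd.transformer would fail on a plain SD pipeline. A minimal restatement with an illustrative local name (denoiser is not from the script):

# PixArt pipelines -> sd.transformer; Stable Diffusion pipelines -> sd.unet
denoiser = sd.transformer if is_pixart else sd.unet
denoiser_sd = denoiser.state_dict()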
@@ -187,7 +187,7 @@ for epoch in range(args.epochs):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor

-        img_batch = color_block_imgs(img_batch, neg1_1=True)
+        # img_batch = color_block_imgs(img_batch, neg1_1=True)

         chunks = torch.chunk(img_batch, batch_size, dim=0)
         # put them so they are side by side
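Around this hunk the script splits the batch with torch.chunk and tiles the frames side by side for a preview image. A standalone sketch of that tiling step on a dummy batch (shapes assumed, not taken from the script):

import torch

img_batch = torch.rand(4, 3, 64, 64)  # dummy (B, C, H, W) batch
chunks = torch.chunk(img_batch, img_batch.shape[0], dim=0)
# concatenate along the width axis so the images sit side by side
big_img = torch.cat([c.squeeze(0) for c in chunks], dim=-1)  # (C, H, B*W)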
@@ -208,9 +208,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)

-        show_img(img)
+        # show_img(img)

-        time.sleep(1.0)
+        # time.sleep(1.0)
         # if not last epoch
         if epoch < args.epochs - 1:
             trigger_dataloader_setup_epoch(dataloader)