mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-14 15:07:22 +00:00
Fixed breaking change with diffusers. Allow flowmatch on normal stable diffusion models.
This commit is contained in:
@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
import torchvision.transforms.functional
|
||||
from toolkit.image_utils import show_img
|
||||
from toolkit.image_utils import show_img, show_tensors
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
@@ -34,7 +34,7 @@ parser.add_argument('--epochs', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
resolution = 512
|
||||
resolution = 1024
|
||||
bucket_tolerance = 64
|
||||
batch_size = 1
|
||||
|
||||
@@ -55,8 +55,8 @@ class FakeSD:
|
||||
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
clip_image_path=dataset_folder,
|
||||
square_crop=True,
|
||||
# clip_image_path=dataset_folder,
|
||||
# square_crop=True,
|
||||
resolution=resolution,
|
||||
# caption_ext='json',
|
||||
default_caption='default',
|
||||
@@ -88,32 +88,37 @@ for epoch in range(args.epochs):
|
||||
|
||||
# img_batch = color_block_imgs(img_batch, neg1_1=True)
|
||||
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
# chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# # put them so they are size by side
|
||||
# big_img = torch.cat(chunks, dim=3)
|
||||
# big_img = big_img.squeeze(0)
|
||||
#
|
||||
# control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
|
||||
# big_control_img = torch.cat(control_chunks, dim=3)
|
||||
# big_control_img = big_control_img.squeeze(0) * 2 - 1
|
||||
#
|
||||
#
|
||||
# # resize control image
|
||||
# big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
|
||||
#
|
||||
# big_img = torch.cat([big_img, big_control_img], dim=2)
|
||||
#
|
||||
# min_val = big_img.min()
|
||||
# max_val = big_img.max()
|
||||
#
|
||||
# big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
|
||||
big_control_img = torch.cat(control_chunks, dim=3)
|
||||
big_control_img = big_control_img.squeeze(0) * 2 - 1
|
||||
big_img = img_batch
|
||||
# big_img = big_img.clamp(-1, 1)
|
||||
|
||||
|
||||
# resize control image
|
||||
big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
|
||||
|
||||
big_img = torch.cat([big_img, big_control_img], dim=2)
|
||||
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
show_tensors(big_img)
|
||||
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
# img = transforms.ToPILImage()(big_img)
|
||||
#
|
||||
# show_img(img)
|
||||
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
time.sleep(0.2)
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
Reference in New Issue
Block a user