Fixed a breaking change with diffusers. Allow flowmatch on normal Stable Diffusion models.

This commit is contained in:
Jaret Burkett
2024-08-22 14:36:22 -06:00
parent e07a98a50c
commit 338c77d677
4 changed files with 37 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
import torchvision.transforms.functional
from toolkit.image_utils import show_img
from toolkit.image_utils import show_img, show_tensors
sys.path.append(SD_SCRIPTS_ROOT)
@@ -34,7 +34,7 @@ parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
resolution = 1024
bucket_tolerance = 64
batch_size = 1
@@ -55,8 +55,8 @@ class FakeSD:
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
clip_image_path=dataset_folder,
square_crop=True,
# clip_image_path=dataset_folder,
# square_crop=True,
resolution=resolution,
# caption_ext='json',
default_caption='default',
@@ -88,32 +88,37 @@ for epoch in range(args.epochs):
# img_batch = color_block_imgs(img_batch, neg1_1=True)
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are side by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
# chunks = torch.chunk(img_batch, batch_size, dim=0)
# # put them so they are side by side
# big_img = torch.cat(chunks, dim=3)
# big_img = big_img.squeeze(0)
#
# control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
# big_control_img = torch.cat(control_chunks, dim=3)
# big_control_img = big_control_img.squeeze(0) * 2 - 1
#
#
# # resize control image
# big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
#
# big_img = torch.cat([big_img, big_control_img], dim=2)
#
# min_val = big_img.min()
# max_val = big_img.max()
#
# big_img = (big_img / 2 + 0.5).clamp(0, 1)
control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
big_control_img = torch.cat(control_chunks, dim=3)
big_control_img = big_control_img.squeeze(0) * 2 - 1
big_img = img_batch
# big_img = big_img.clamp(-1, 1)
# resize control image
big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
big_img = torch.cat([big_img, big_control_img], dim=2)
min_val = big_img.min()
max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1)
show_tensors(big_img)
# convert to image
img = transforms.ToPILImage()(big_img)
# img = transforms.ToPILImage()(big_img)
#
# show_img(img)
show_img(img)
time.sleep(1.0)
time.sleep(0.2)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)