mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
t2i training working from what I can tell at least
This commit is contained in:
@@ -250,7 +250,7 @@ class StableDiffusion:
|
||||
# add hacks to unet to help training
|
||||
# pipe.unet = prepare_unet_for_training(pipe.unet)
|
||||
|
||||
self.unet = pipe.unet
|
||||
self.unet: 'UNet2DConditionModel' = pipe.unet
|
||||
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||
self.vae.eval()
|
||||
self.vae.requires_grad_(False)
|
||||
@@ -360,8 +360,9 @@ class StableDiffusion:
|
||||
extra = {}
|
||||
if gen_config.adapter_image_path is not None:
|
||||
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
|
||||
validation_image = validation_image.resize((gen_config.width, gen_config.height))
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = 1.0
|
||||
|
||||
if self.network is not None:
|
||||
self.network.multiplier = gen_config.network_multiplier
|
||||
@@ -933,7 +934,7 @@ class StableDiffusion:
|
||||
self.device_state['adapter'] = {
|
||||
'training': self.adapter.training,
|
||||
'device': self.adapter.device,
|
||||
'requires_grad': self.adapter.requires_grad,
|
||||
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
|
||||
Reference in New Issue
Block a user