Added base setup for training T2I adapters. Currently untested; saw something else shiny I wanted to finish first. Also added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, cubic sampling is used to favor the timesteps most useful for each: later timesteps for style, earlier timesteps for content.
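For illustration, a minimal sketch of the cubic sampling scheme described above. The function and parameter names here are hypothetical stand-ins, not the commit's actual code:

# illustrative sketch only -- names are hypothetical, not from this commit
import torch

def sample_timesteps(batch_size, num_train_timesteps, content_or_style='balanced'):
    if content_or_style == 'balanced':
        # standard uniform sampling over all timesteps
        return torch.randint(0, num_train_timesteps, (batch_size,))
    u = torch.rand(batch_size)
    if content_or_style == 'style':
        # 1 - u**3 piles probability mass near T (later, noisier steps)
        t = (1 - u ** 3) * num_train_timesteps
    else:  # 'content'
        # u**3 piles probability mass near 0 (earlier, cleaner steps)
        t = u ** 3 * num_train_timesteps
    return t.long().clamp(0, num_train_timesteps - 1)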

Jaret Burkett
2023-09-16 08:30:38 -06:00
parent 17e4fe40d7
commit 27f343fc08
8 changed files with 314 additions and 84 deletions

@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch
 import copy
 
@@ -15,16 +17,22 @@ empty_preset = {
         'training': False,
         'requires_grad': False,
         'device': 'cpu',
-    }
+    },
+    'adapter': {
+        'training': False,
+        'requires_grad': False,
+        'device': 'cpu',
+    },
 }
 
 
-def get_train_sd_device_state_preset (
-        device: torch.DeviceObjType,
+def get_train_sd_device_state_preset(
+        device: Union[str, torch.device],
         train_unet: bool = False,
         train_text_encoder: bool = False,
         cached_latents: bool = False,
         train_lora: bool = False,
+        train_adapter: bool = False,
         train_embedding: bool = False,
 ):
     preset = copy.deepcopy(empty_preset)
@@ -51,9 +59,14 @@ def get_train_sd_device_state_preset (
 
         preset['text_encoder']['training'] = True
         preset['unet']['training'] = True
 
     if train_lora:
         preset['text_encoder']['requires_grad'] = False
         preset['unet']['requires_grad'] = False
+    if train_adapter:
+        preset['adapter']['requires_grad'] = True
+        preset['adapter']['training'] = True
+        preset['adapter']['device'] = device
+        preset['unet']['training'] = True
 
     return preset
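
For reference, a hypothetical call showing what the new flag does to the preset, based only on the dict shape and assignments visible in the diff above:

# hypothetical usage; assumes the helper is imported from its module
preset = get_train_sd_device_state_preset(
    device='cuda:0',
    train_adapter=True,
)
# per the diff, the adapter entry is marked trainable and moved to the
# requested device; the unet is also flagged as training, presumably so
# the adapter's features pass through it with gradients during training
assert preset['adapter'] == {
    'training': True,
    'requires_grad': True,
    'device': 'cuda:0',
}
assert preset['unet']['training'] is True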