# Mirror of https://github.com/ostris/ai-toolkit.git
# Synced 2026-01-26 16:39:47 +00:00
# 60 lines, 1.5 KiB, Python
import torch
|
|
import copy
|
|
|
|
# Baseline device-state template: one entry per model component ('vae',
# 'unet', 'text_encoder'). Every component starts with training disabled
# and its device set to 'cpu'; 'unet' and 'text_encoder' additionally
# carry a 'requires_grad' flag that starts as False.
# Treat this as read-only and deep-copy it before modifying, so the
# nested dicts are not shared between presets.
empty_preset = {
    'vae': {
        'training': False,
        'device': 'cpu',
    },
    'unet': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    },
    'text_encoder': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    },
}
|
|
|
|
|
|
def get_train_sd_device_state_preset(
        device: torch.DeviceObjType,
        train_unet: bool = False,
        train_text_encoder: bool = False,
        cached_latents: bool = False,
        train_lora: bool = False,
        train_embedding: bool = False,
):
    """Build a device-state preset dict for a training run.

    Starts from a deep copy of ``empty_preset`` (so the module-level
    template is never mutated) and then applies the flags in a fixed
    order; later flags override earlier ones.

    Args:
        device: Value stored as the ``'device'`` entry of the components
            that are moved off the template default.
            NOTE(review): the annotation names ``torch.DeviceObjType``;
            callers presumably pass a ``torch.device`` or device string —
            confirm against call sites.
        train_unet: Set ``'training'`` and ``'requires_grad'`` to True
            for the ``'unet'`` entry.
        train_text_encoder: Set ``'training'`` and ``'requires_grad'`` to
            True for the ``'text_encoder'`` entry.
        cached_latents: When True, the ``'vae'`` entry keeps the
            template's device instead of being assigned ``device``.
        train_lora: Force ``'requires_grad'`` to False for both
            ``'unet'`` and ``'text_encoder'``. Applied last, so it
            overrides the values set by the other flags.
        train_embedding: Set ``'training'`` and ``'requires_grad'`` to
            True for ``'text_encoder'`` and ``'training'`` to True for
            ``'unet'``.

    Returns:
        A new nested dict with the same keys as ``empty_preset``.
    """
    preset = copy.deepcopy(empty_preset)

    # The vae entry is only assigned `device` when latents are not cached.
    if not cached_latents:
        preset['vae']['device'] = device

    # unet and text encoder are assigned `device` whether or not they are
    # being trained.
    preset['unet']['device'] = device
    preset['text_encoder']['device'] = device

    if train_unet:
        preset['unet']['training'] = True
        preset['unet']['requires_grad'] = True

    if train_text_encoder:
        preset['text_encoder']['training'] = True
        preset['text_encoder']['requires_grad'] = True

    if train_embedding:
        preset['text_encoder']['training'] = True
        preset['text_encoder']['requires_grad'] = True
        preset['unet']['training'] = True

    # Must stay last: clears requires_grad even if a flag above set it.
    if train_lora:
        preset['text_encoder']['requires_grad'] = False
        preset['unet']['requires_grad'] = False

    return preset
|