Files
ai-toolkit/toolkit/sd_device_states_presets.py

60 lines
1.5 KiB
Python

import torch
import copy
# Baseline state for every model component: not training and placed on the
# CPU. The 'unet' and 'text_encoder' entries additionally carry a
# 'requires_grad' flag (default False); the 'vae' entry has no such key.
# get_train_sd_device_state_preset() deep-copies this dict and overrides
# individual entries, so this template itself is never mutated there.
empty_preset = {
    'vae': {
        'training': False,
        'device': 'cpu',
    },
    'unet': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    },
    'text_encoder': {
        'training': False,
        'requires_grad': False,
        'device': 'cpu',
    },
}
def get_train_sd_device_state_preset(
        device: torch.DeviceObjType,
        train_unet: bool = False,
        train_text_encoder: bool = False,
        cached_latents: bool = False,
        train_lora: bool = False,
        train_embedding: bool = False,
):
    """Return a device-state preset dict configured for a training run.

    The result is a deep copy of ``empty_preset`` with entries overridden
    according to the flags; ``empty_preset`` itself is left untouched.

    Args:
        device: Value stored as the ``'device'`` of the unet and the text
            encoder, and of the vae unless ``cached_latents`` is set.
        train_unet: Marks the unet as training with ``requires_grad`` on.
        train_text_encoder: Marks the text encoder as training with
            ``requires_grad`` on.
        cached_latents: When truthy, the vae keeps its default ``'cpu'``
            device instead of being assigned ``device``.
        train_lora: When truthy, ``requires_grad`` ends up ``False`` for
            both the unet and the text encoder, overriding the other flags.
        train_embedding: Marks both the text encoder and the unet as
            training, and turns ``requires_grad`` on for the text encoder.

    Returns:
        A new dict with ``'vae'``, ``'unet'`` and ``'text_encoder'`` entries.
    """
    preset = copy.deepcopy(empty_preset)
    vae_state = preset['vae']
    unet_state = preset['unet']
    te_state = preset['text_encoder']

    # The unet and text encoder are assigned the target device whether or
    # not they are being trained; only the vae placement is conditional.
    unet_state['device'] = device
    te_state['device'] = device
    if not cached_latents:
        vae_state['device'] = device

    # train_lora has the final say on requires_grad: whenever it is set the
    # base unet/text encoder entries keep requires_grad False.
    # NOTE(review): presumably because only the LoRA weights should receive
    # gradients - confirm against the caller.
    base_grads_allowed = not train_lora

    # Embedding training puts both networks in training mode, but only the
    # text encoder gets requires_grad from it.
    if train_unet or train_embedding:
        unet_state['training'] = True
    if train_unet and base_grads_allowed:
        unet_state['requires_grad'] = True

    if train_text_encoder or train_embedding:
        te_state['training'] = True
        if base_grads_allowed:
            te_state['requires_grad'] = True

    return preset