Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -263,6 +263,8 @@ def load_custom_adapter_model(
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
device = device if isinstance(device, torch.device) else torch.device(device)
dtype = dtype if isinstance(dtype, torch.dtype) else get_torch_dtype(dtype)
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)