mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-11 21:49:49 +00:00
Added training for pixart-a
This commit is contained in:
@@ -263,6 +263,8 @@ def load_custom_adapter_model(
|
||||
if path_to_file.endswith('.safetensors'):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
combined_state_dict = OrderedDict()
|
||||
device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
dtype = dtype if isinstance(dtype, torch.dtype) else get_torch_dtype(dtype)
|
||||
for combo_key, value in raw_state_dict.items():
|
||||
key_split = combo_key.split('.')
|
||||
module_name = key_split.pop(0)
|
||||
|
||||
Reference in New Issue
Block a user