mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Compatability fixes
This commit is contained in:
@@ -216,9 +216,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# called before LoRA network is loaded but after model is loaded
|
||||
# attach the adapter here so it is there before we load the network
|
||||
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
|
||||
if self.sd.is_xl:
|
||||
if self.model_config.is_xl:
|
||||
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
|
||||
|
||||
print(f"Loading T2I Adapter from {adapter_path}")
|
||||
|
||||
# dont name this adapter since we are not training it
|
||||
self.t2i_adapter = T2IAdapter.from_pretrained(
|
||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||
|
||||
Reference in New Issue
Block a user