Compatability fixes

This commit is contained in:
Jaret Burkett
2023-09-29 14:07:37 -06:00
parent 8509da60cb
commit 8d9450ad7c
3 changed files with 7 additions and 3 deletions

View File

@@ -748,6 +748,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
sig = inspect.signature(self.network.prepare_optimizer_params)
if 'default_lr' in sig.parameters:
config['default_lr'] = self.train_config.lr
if 'learning_rate' in sig.parameters:
config['learning_rate'] = self.train_config.lr
params_net = self.network.prepare_optimizer_params(
**config
)

View File

@@ -216,9 +216,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
# called before LoRA network is loaded but after model is loaded
# attach the adapter here so it is there before we load the network
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
if self.sd.is_xl:
if self.model_config.is_xl:
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
print(f"Loading T2I Adapter from {adapter_path}")
# dont name this adapter since we are not training it
self.t2i_adapter = T2IAdapter.from_pretrained(
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"

View File

@@ -1,9 +1,9 @@
torch
torchvision
safetensors
diffusers==0.21.1
diffusers==0.21.3
transformers
lycoris_lora
lycoris-lora==1.8.3
flatten_json
pyyaml
oyaml