mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Compatability fixes
This commit is contained in:
@@ -748,6 +748,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
sig = inspect.signature(self.network.prepare_optimizer_params)
|
sig = inspect.signature(self.network.prepare_optimizer_params)
|
||||||
if 'default_lr' in sig.parameters:
|
if 'default_lr' in sig.parameters:
|
||||||
config['default_lr'] = self.train_config.lr
|
config['default_lr'] = self.train_config.lr
|
||||||
|
if 'learning_rate' in sig.parameters:
|
||||||
|
config['learning_rate'] = self.train_config.lr
|
||||||
params_net = self.network.prepare_optimizer_params(
|
params_net = self.network.prepare_optimizer_params(
|
||||||
**config
|
**config
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -216,9 +216,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# called before LoRA network is loaded but after model is loaded
|
# called before LoRA network is loaded but after model is loaded
|
||||||
# attach the adapter here so it is there before we load the network
|
# attach the adapter here so it is there before we load the network
|
||||||
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
|
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
|
||||||
if self.sd.is_xl:
|
if self.model_config.is_xl:
|
||||||
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
|
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
|
||||||
|
|
||||||
|
print(f"Loading T2I Adapter from {adapter_path}")
|
||||||
|
|
||||||
# dont name this adapter since we are not training it
|
# dont name this adapter since we are not training it
|
||||||
self.t2i_adapter = T2IAdapter.from_pretrained(
|
self.t2i_adapter = T2IAdapter.from_pretrained(
|
||||||
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
safetensors
|
safetensors
|
||||||
diffusers==0.21.1
|
diffusers==0.21.3
|
||||||
transformers
|
transformers
|
||||||
lycoris_lora
|
lycoris-lora==1.8.3
|
||||||
flatten_json
|
flatten_json
|
||||||
pyyaml
|
pyyaml
|
||||||
oyaml
|
oyaml
|
||||||
|
|||||||
Reference in New Issue
Block a user