mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Pixel shuffle adapter. Some bug fixes thrown in
This commit is contained in:
@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.models.clip_fusion import CLIPFusionModule
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
from toolkit.models.control_lora_adapter import ControlLoraAdapter
|
||||
from toolkit.models.subpixel_adapter import SubpixelAdapter
|
||||
from toolkit.models.ilora import InstantLoRAModule
|
||||
from toolkit.models.single_value_adapter import SingleValueAdapter
|
||||
from toolkit.models.te_adapter import TEAdapter
|
||||
@@ -103,6 +104,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.single_value_adapter: SingleValueAdapter = None
|
||||
self.redux_adapter: ReduxImageEncoder = None
|
||||
self.control_lora: ControlLoraAdapter = None
|
||||
self.subpixel_adapter: SubpixelAdapter = None
|
||||
|
||||
self.conditional_embeds: Optional[torch.Tensor] = None
|
||||
self.unconditional_embeds: Optional[torch.Tensor] = None
|
||||
@@ -253,6 +255,13 @@ class CustomAdapter(torch.nn.Module):
|
||||
config=self.config,
|
||||
train_config=self.train_config
|
||||
)
|
||||
elif self.adapter_type == 'subpixel':
|
||||
self.subpixel_adapter = SubpixelAdapter(
|
||||
self,
|
||||
sd=self.sd_ref(),
|
||||
config=self.config,
|
||||
train_config=self.train_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown adapter type: {self.adapter_type}")
|
||||
|
||||
@@ -284,7 +293,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
def setup_clip(self):
|
||||
adapter_config = self.config
|
||||
sd = self.sd_ref()
|
||||
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
|
||||
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
|
||||
return
|
||||
if self.config.type == 'photo_maker':
|
||||
try:
|
||||
@@ -502,6 +511,14 @@ class CustomAdapter(torch.nn.Module):
|
||||
for k2, v2 in v.items():
|
||||
new_dict[k + '.' + k2] = v2
|
||||
self.control_lora.load_weights(new_dict, strict=strict)
|
||||
|
||||
if self.adapter_type == 'subpixel':
|
||||
# state dict is seperated. so recombine it
|
||||
new_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
for k2, v2 in v.items():
|
||||
new_dict[k + '.' + k2] = v2
|
||||
self.subpixel_adapter.load_weights(new_dict, strict=strict)
|
||||
|
||||
pass
|
||||
|
||||
@@ -558,6 +575,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
for k, v in d.items():
|
||||
state_dict[k] = v
|
||||
return state_dict
|
||||
elif self.adapter_type == 'subpixel':
|
||||
d = self.subpixel_adapter.get_state_dict()
|
||||
for k, v in d.items():
|
||||
state_dict[k] = v
|
||||
return state_dict
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -702,7 +724,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
prompt: Union[List[str], str],
|
||||
is_unconditional: bool = False,
|
||||
):
|
||||
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
|
||||
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
|
||||
return prompt
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
# todo allow for training
|
||||
@@ -1225,6 +1247,10 @@ class CustomAdapter(torch.nn.Module):
|
||||
param_list = self.control_lora.get_params()
|
||||
for param in param_list:
|
||||
yield param
|
||||
elif self.config.type == 'subpixel':
|
||||
param_list = self.subpixel_adapter.get_params()
|
||||
for param in param_list:
|
||||
yield param
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user