Pixel shuffle adapter. Some bug fixes thrown in

This commit is contained in:
Jaret Burkett
2025-03-29 21:15:01 -06:00
parent b94d7aafea
commit 860d892214
10 changed files with 594 additions and 11 deletions

View File

@@ -11,6 +11,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.control_lora_adapter import ControlLoraAdapter
from toolkit.models.subpixel_adapter import SubpixelAdapter
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.single_value_adapter import SingleValueAdapter
from toolkit.models.te_adapter import TEAdapter
@@ -103,6 +104,7 @@ class CustomAdapter(torch.nn.Module):
self.single_value_adapter: SingleValueAdapter = None
self.redux_adapter: ReduxImageEncoder = None
self.control_lora: ControlLoraAdapter = None
self.subpixel_adapter: SubpixelAdapter = None
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -253,6 +255,13 @@ class CustomAdapter(torch.nn.Module):
config=self.config,
train_config=self.train_config
)
elif self.adapter_type == 'subpixel':
self.subpixel_adapter = SubpixelAdapter(
self,
sd=self.sd_ref(),
config=self.config,
train_config=self.train_config
)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -284,7 +293,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora"]:
if self.config.type in ["text_encoder", "llm_adapter", "single_value", "control_lora", "subpixel"]:
return
if self.config.type == 'photo_maker':
try:
@@ -502,6 +511,14 @@ class CustomAdapter(torch.nn.Module):
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.control_lora.load_weights(new_dict, strict=strict)
if self.adapter_type == 'subpixel':
# state dict is seperated. so recombine it
new_dict = {}
for k, v in state_dict.items():
for k2, v2 in v.items():
new_dict[k + '.' + k2] = v2
self.subpixel_adapter.load_weights(new_dict, strict=strict)
pass
@@ -558,6 +575,11 @@ class CustomAdapter(torch.nn.Module):
for k, v in d.items():
state_dict[k] = v
return state_dict
elif self.adapter_type == 'subpixel':
d = self.subpixel_adapter.get_state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
else:
raise NotImplementedError
@@ -702,7 +724,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora']:
if self.adapter_type in ['clip_fusion', 'ilora', 'vision_direct', 'redux', 'control_lora', 'subpixel']:
return prompt
elif self.adapter_type == 'text_encoder':
# todo allow for training
@@ -1225,6 +1247,10 @@ class CustomAdapter(torch.nn.Module):
param_list = self.control_lora.get_params()
for param in param_list:
yield param
elif self.config.type == 'subpixel':
param_list = self.subpixel_adapter.get_params()
for param in param_list:
yield param
else:
raise NotImplementedError