Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -13,7 +13,7 @@ from toolkit.models.te_adapter import TEAdapter
from toolkit.models.vd_adapter import VisionDirectAdapter
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
@@ -99,6 +99,13 @@ class CustomAdapter(torch.nn.Module):
tokenizer.add_tokens([self.flag_word], special_tokens=True)
else:
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
elif self.config.name_or_path is not None:
loaded_state_dict = load_custom_adapter_model(
self.config.name_or_path,
self.sd_ref().device,
dtype=self.sd_ref().dtype,
)
self.load_state_dict(loaded_state_dict, strict=False)
def setup_adapter(self):
if self.adapter_type == 'photo_maker':
@@ -287,6 +294,9 @@ class CustomAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
if self.config.train_only_image_encoder and 'vd_adapter' not in state_dict and 'dvadapter' not in state_dict:
# we are loading pure clip weights.
self.vision_encoder.load_state_dict(state_dict, strict=strict)
if 'lora_weights' in state_dict:
# todo add LoRA
@@ -332,6 +342,8 @@ class CustomAdapter(torch.nn.Module):
if 'vd_adapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
if 'dvadapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
if 'vision_encoder' in state_dict and self.config.train_image_encoder:
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
@@ -346,6 +358,9 @@ class CustomAdapter(torch.nn.Module):
def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
if self.config.train_only_image_encoder:
return self.vision_encoder.state_dict()
if self.adapter_type == 'photo_maker':
if self.config.train_image_encoder:
state_dict["id_encoder"] = self.vision_encoder.state_dict()
@@ -364,7 +379,9 @@ class CustomAdapter(torch.nn.Module):
state_dict["te_adapter"] = self.te_adapter.state_dict()
return state_dict
elif self.adapter_type == 'vision_direct':
state_dict["vd_adapter"] = self.vd_adapter.state_dict()
state_dict["dvadapter"] = self.vd_adapter.state_dict()
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
return state_dict
elif self.adapter_type == 'ilora':
if self.config.train_image_encoder:
@@ -617,6 +634,12 @@ class CustomAdapter(torch.nn.Module):
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
return clip_image.detach()
def train(self, mode: bool = True):
if self.config.train_image_encoder:
self.vision_encoder.train(mode)
else:
super().train(mode)
def trigger_pre_te(
self,
tensors_0_1: torch.Tensor,
@@ -735,6 +758,9 @@ class CustomAdapter(torch.nn.Module):
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
yield from self.vision_encoder.parameters(recurse)
return
if self.config.type == 'photo_maker':
yield from self.fuse_module.parameters(recurse)
if self.config.train_image_encoder:
@@ -753,5 +779,13 @@ class CustomAdapter(torch.nn.Module):
elif self.config.type == 'vision_direct':
for attn_processor in self.vd_adapter.adapter_modules:
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError
def enable_gradient_checkpointing(self):
if hasattr(self.vision_encoder, "enable_gradient_checkpointing"):
self.vision_encoder.enable_gradient_checkpointing()
elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
self.vision_encoder.gradient_checkpointing = True