mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 00:39:22 +00:00
Added training for pixart-a
This commit is contained in:
@@ -13,7 +13,7 @@ from toolkit.models.te_adapter import TEAdapter
|
||||
from toolkit.models.vd_adapter import VisionDirectAdapter
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
@@ -99,6 +99,13 @@ class CustomAdapter(torch.nn.Module):
|
||||
tokenizer.add_tokens([self.flag_word], special_tokens=True)
|
||||
else:
|
||||
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
|
||||
elif self.config.name_or_path is not None:
|
||||
loaded_state_dict = load_custom_adapter_model(
|
||||
self.config.name_or_path,
|
||||
self.sd_ref().device,
|
||||
dtype=self.sd_ref().dtype,
|
||||
)
|
||||
self.load_state_dict(loaded_state_dict, strict=False)
|
||||
|
||||
def setup_adapter(self):
|
||||
if self.adapter_type == 'photo_maker':
|
||||
@@ -287,6 +294,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
strict = False
|
||||
if self.config.train_only_image_encoder and 'vd_adapter' not in state_dict and 'dvadapter' not in state_dict:
|
||||
# we are loading pure clip weights.
|
||||
self.vision_encoder.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
if 'lora_weights' in state_dict:
|
||||
# todo add LoRA
|
||||
@@ -332,6 +342,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
if 'vd_adapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
|
||||
if 'dvadapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
|
||||
|
||||
if 'vision_encoder' in state_dict and self.config.train_image_encoder:
|
||||
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
|
||||
@@ -346,6 +358,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
|
||||
def state_dict(self) -> OrderedDict:
|
||||
state_dict = OrderedDict()
|
||||
if self.config.train_only_image_encoder:
|
||||
return self.vision_encoder.state_dict()
|
||||
|
||||
if self.adapter_type == 'photo_maker':
|
||||
if self.config.train_image_encoder:
|
||||
state_dict["id_encoder"] = self.vision_encoder.state_dict()
|
||||
@@ -364,7 +379,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
state_dict["te_adapter"] = self.te_adapter.state_dict()
|
||||
return state_dict
|
||||
elif self.adapter_type == 'vision_direct':
|
||||
state_dict["vd_adapter"] = self.vd_adapter.state_dict()
|
||||
state_dict["dvadapter"] = self.vd_adapter.state_dict()
|
||||
if self.config.train_image_encoder:
|
||||
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
|
||||
return state_dict
|
||||
elif self.adapter_type == 'ilora':
|
||||
if self.config.train_image_encoder:
|
||||
@@ -617,6 +634,12 @@ class CustomAdapter(torch.nn.Module):
|
||||
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
|
||||
return clip_image.detach()
|
||||
|
||||
def train(self, mode: bool = True):
    """Switch training/eval mode, honouring the image-encoder training flag.

    When ``self.config.train_image_encoder`` is truthy, only
    ``self.vision_encoder.train(mode)`` is called; the parent class's
    ``train`` is not invoked on that path. Otherwise the call is delegated
    to the parent class.

    Args:
        mode: ``True`` for training mode, ``False`` for evaluation mode.

    Returns:
        This module on the image-encoder path, or whatever the parent
        ``train`` returns (``torch.nn.Module.train`` returns the module
        itself), so calls such as ``adapter.train().to(...)`` can be chained.
    """
    if self.config.train_image_encoder:
        self.vision_encoder.train(mode)
        # Match the torch.nn.Module.train contract of returning the module.
        return self
    return super().train(mode)
|
||||
|
||||
def trigger_pre_te(
|
||||
self,
|
||||
tensors_0_1: torch.Tensor,
|
||||
@@ -735,6 +758,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
if self.config.train_only_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
return
|
||||
if self.config.type == 'photo_maker':
|
||||
yield from self.fuse_module.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
@@ -753,5 +779,13 @@ class CustomAdapter(torch.nn.Module):
|
||||
elif self.config.type == 'vision_direct':
|
||||
for attn_processor in self.vd_adapter.adapter_modules:
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def enable_gradient_checkpointing(self):
    """Turn on gradient checkpointing for the vision encoder if it supports it.

    Calls the encoder's own ``enable_gradient_checkpointing()`` when that
    attribute exists. Failing that, sets an existing
    ``gradient_checkpointing`` attribute to ``True``. Encoders exposing
    neither are left untouched.
    """
    encoder = self.vision_encoder

    # Prefer the encoder's dedicated method when it has one.
    if hasattr(encoder, "enable_gradient_checkpointing"):
        encoder.enable_gradient_checkpointing()
        return

    # Otherwise flip the plain flag, but only if the encoder already has it.
    if hasattr(encoder, 'gradient_checkpointing'):
        encoder.gradient_checkpointing = True
|
||||
|
||||
Reference in New Issue
Block a user