mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added initial support for training i2v adapter WIP
This commit is contained in:
@@ -151,14 +151,14 @@ class NetworkConfig:
|
||||
self.lokr_factor = kwargs.get('lokr_factor', -1)
|
||||
|
||||
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora']
|
||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
|
||||
|
||||
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
|
||||
|
||||
|
||||
class AdapterConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net, i2v
|
||||
self.in_channels: int = kwargs.get('in_channels', 3)
|
||||
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
|
||||
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
|
||||
@@ -255,6 +255,10 @@ class AdapterConfig:
|
||||
|
||||
# for subpixel adapter
|
||||
self.subpixel_downscale_factor: int = kwargs.get('subpixel_downscale_factor', 8)
|
||||
|
||||
# for i2v adapter
|
||||
# append the masked start frame. During pretraining we will only do the vision encoder
|
||||
self.i2v_do_start_frame: bool = kwargs.get('i2v_do_start_frame', False)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -955,6 +959,8 @@ class GenerateImageConfig:
|
||||
# video
|
||||
if self.num_frames == 1:
|
||||
raise ValueError(f"Expected 1 img but got a list {len(image)}")
|
||||
if self.num_frames > 1 and self.output_ext not in ['webp']:
|
||||
self.output_ext = 'webp'
|
||||
if self.output_ext == 'webp':
|
||||
# save as animated webp
|
||||
duration = 1000 // self.fps # Convert fps to milliseconds per frame
|
||||
@@ -1075,6 +1081,8 @@ class GenerateImageConfig:
|
||||
self.extra_values = [float(val) for val in content.split(',')]
|
||||
elif flag == 'frames':
|
||||
self.num_frames = int(content)
|
||||
elif flag == 'num_frames':
|
||||
self.num_frames = int(content)
|
||||
elif flag == 'fps':
|
||||
self.fps = int(content)
|
||||
elif flag == 'ctrl_img':
|
||||
|
||||
Reference in New Issue
Block a user