mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Add LTX-2 Support (#644)
* WIP, adding support for LTX2 * Training on images working * Fix loading comfy models * Handle converting and deconverting lora so it matches original format * Reworked ui to handle ltx and proper dataset default overwriting. * Update the way lokr saves so it is more compatible with comfy * Audio loading and synchronization/resampling is working * Add audio to training. Does it work? Maybe, still testing. * Fixed fps default issue for sound * Have ui set fps for accurate audio mapping on ltx * Added audio processing options to the ui for ltx * Clean up requirements
This commit is contained in:
@@ -123,9 +123,13 @@ class FileItemDTO(
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.prior_reg = self.dataset_config.prior_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
|
||||
def cleanup(self):
|
||||
self.tensor = None
|
||||
self.audio_data = None
|
||||
self.audio_tensor = None
|
||||
self.cleanup_latent()
|
||||
self.cleanup_text_embedding()
|
||||
self.cleanup_control()
|
||||
@@ -154,6 +158,13 @@ class DataLoaderBatchDTO:
|
||||
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
|
||||
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
|
||||
self.audio_tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
# just for holding noise and preds during training
|
||||
self.audio_target: Union[torch.Tensor, None] = None
|
||||
self.audio_pred: Union[torch.Tensor, None] = None
|
||||
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -304,6 +315,21 @@ class DataLoaderBatchDTO:
|
||||
y.text_embeds = [y.text_embeds]
|
||||
prompt_embeds_list.append(y)
|
||||
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
|
||||
|
||||
if any([x.audio_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_audio_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is not None:
|
||||
base_audio_tensor = x.audio_tensor
|
||||
break
|
||||
audio_tensors = []
|
||||
for x in self.file_items:
|
||||
if x.audio_tensor is None:
|
||||
audio_tensors.append(torch.zeros_like(base_audio_tensor))
|
||||
else:
|
||||
audio_tensors.append(x.audio_tensor)
|
||||
self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
|
||||
|
||||
|
||||
except Exception as e:
|
||||
@@ -336,6 +362,10 @@ class DataLoaderBatchDTO:
|
||||
del self.latents
|
||||
del self.tensor
|
||||
del self.control_tensor
|
||||
del self.audio_tensor
|
||||
del self.audio_data
|
||||
del self.audio_target
|
||||
del self.audio_pred
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user