mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 08:20:35 +00:00
* Added base ACE-Step 1.5 XL. Generation works; training and UI still WIP
* Base training code done
* Fixed some issues with caching text embeddings. Updated sample cards to show audio
* Fixed an issue with quantizing ACE-Step
* Added album artwork to samples with waveform
* Cleaned up logs
* Added album art endpoint to speed up album art loading
* Made a script to make videos with artwork
* Made the UI handle basic audio models. Multi-line adjustments to the editor and better syntax highlighting
* Added a prompt tagging system for specially tagged models
* Prompt tagging processing working in the UI
* Moved default samples to a dedicated file so more can be added when needed and adjusted per model
* Added a captioner job with a music captioner, prepped for use with the UI
* Added basic UI setup for the captioning modal and handling of captioning jobs
* Starting a captioning job from the UI works. Still needs better management
* Better filtering of job options in the job view for captioning jobs
* Added Qwen3 VL as a captioner for images
* Added an indicator when a dataset is being captioned
* Adjusted the way caption jobs look in the queue
* Fixed a few issues. Adjusted defaults
* Version bump
* Added ACE-Step to the README
100 lines
3.3 KiB
Python
import json

import torch

from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds


class BaseAudioModel(BaseModel):
    sample_rate = 48000

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_audio_model = True

    def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # The base model calls this hook; delegate to the audio-specific
        # method so the naming makes sense for audio models.
        return self.generate_single_audio(
            pipeline,
            gen_config,
            conditional_embeds,
            unconditional_embeds,
            generator,
            extra,
        )

    def generate_single_audio(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # Audio model subclasses must implement this.
        raise NotImplementedError(
            "generate_single_audio is not implemented for this model"
        )

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        # The model, VAE, text encoder, and tokenizer are trained together and
        # depend on each other, so they must be saved together.
        raise NotImplementedError(
            "save_model is not implemented for this model. Use the pipeline directly instead."
        )

    def convert_lora_weights_before_save(self, state_dict):
        # Keys currently start with "transformer." but need to start with
        # "diffusion_model." for ComfyUI.
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # Saved with "diffusion_model." prefixes, but ai-toolkit expects
        # "transformer." prefixes.
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def encode_images(self, image_list: torch.Tensor, device=None, dtype=None):
        # make it more obvious for audio models
        return self.encode_audio(image_list, device=device, dtype=dtype)

    def encode_audio(self, audio_tensor: torch.Tensor, device=None, dtype=None):
        if device is None:
            device = self.device_torch
        if dtype is None:
            dtype = self.torch_dtype
        if self.vae.device == torch.device("cpu"):
            self.vae.to(device)
        return self.vae.encode(audio_tensor.to(device=device, dtype=dtype))
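
A concrete audio model plugs in by overriding generate_single_audio. As a minimal sketch, assuming a diffusers-style pipeline call and that PromptEmbeds exposes its raw embeddings as text_embeds; MyAudioModel and the pipeline keyword names are hypothetical, not part of ai-toolkit:

# Hypothetical subclass sketch; the class name and pipeline signature are
# assumptions for illustration, not the toolkit's actual ACE-Step model.
class MyAudioModel(BaseAudioModel):
    def generate_single_audio(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # Assumed diffusers-style call; the real keyword names depend on the
        # concrete pipeline.
        return pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            generator=generator,
        )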
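
The two convert_lora_weights_* hooks are a plain prefix swap, so a save/load round trip is lossless. A quick illustration (the weight key and the model variable are made up for the example):

# Illustrative only: "model" is any BaseAudioModel subclass instance and the
# key below is a made-up LoRA weight name.
sd = {"transformer.blocks.0.attn.lora_A.weight": torch.zeros(4, 4)}

saved = model.convert_lora_weights_before_save(sd)
# keys now start with "diffusion_model." for ComfyUI

restored = model.convert_lora_weights_before_load(saved)
assert list(restored) == list(sd)  # back to "transformer." for ai-toolkit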