mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-06 21:49:57 +00:00
88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
|
|
import torch
|
|
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
|
from toolkit.models.base_model import BaseModel
|
|
from toolkit.prompt_utils import PromptEmbeds
|
|
from toolkit.paths import REPOS_ROOT
|
|
import sys
|
|
import os
|
|
|
|
import gc
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import sys
|
|
import types
|
|
from contextlib import contextmanager
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.cuda.amp as amp
|
|
import torch.distributed as dist
|
|
from tqdm import tqdm
|
|
|
|
|
|
# Directory under REPOS_ROOT that gets added to the module search path.
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")

# This has to happen before the `wan.*` imports below are executed
# (presumably the `wan` package lives inside WAN_ROOT -- confirm).
sys.path.append(WAN_ROOT)

# NOTE(review): the always-true guard looks like a way to keep import
# sorters/formatters from hoisting these imports above the sys.path
# change -- confirm before removing it.
if True:
    from wan.text2video import WanT2V
    from wan.distributed.fsdp import shard_model
    from wan.modules.model import WanModel
    from wan.modules.t5 import T5EncoderModel
    from wan.modules.vae import WanVAE
    from wan.utils.fm_solvers import (
        FlowDPMSolverMultistepScheduler,
        get_sampling_sigmas,
        retrieve_timesteps,
    )
    from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
|
|
|
class Wan21(BaseModel):
    """Wan 2.1 model wrapper built on ``BaseModel``.

    Only construction is implemented. Every model hook below is still a
    placeholder that raises ``NotImplementedError`` when called.
    """

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        """Forward all arguments to ``BaseModel`` and mark the model as flow matching.

        Args:
            device: Passed through to ``BaseModel.__init__`` unchanged.
            model_config: Passed through to ``BaseModel.__init__`` unchanged.
            dtype: Passed through to ``BaseModel.__init__``; defaults to ``'bf16'``.
            custom_pipeline: Passed through to ``BaseModel.__init__`` unchanged.
            noise_scheduler: Passed through to ``BaseModel.__init__`` unchanged.
            **kwargs: Extra keyword arguments forwarded to ``BaseModel.__init__``.
        """
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True

    # these must be implemented in child classes

    def load_model(self):
        """Load the underlying Wan 2.1 pipeline (not implemented yet).

        The previous body called ``Wan21()`` with no arguments from inside
        ``Wan21`` itself. Because ``device`` and ``model_config`` are
        required, that could only fail with a ``TypeError``. Raise an
        explicit error instead, like the other placeholder hooks.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: build the real pipeline here and assign it to self.pipeline.
        raise NotImplementedError(
            "load_model is not implemented for Wan21 yet")

    def get_generation_pipeline(self):
        """Return the pipeline used for sampling (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # override this in child classes
        raise NotImplementedError(
            "get_generation_pipeline must be implemented in child classes")

    def generate_single_image(
            self,
            gen_config: GenerateImageConfig,
            conditional_embeds: PromptEmbeds,
            unconditional_embeds: PromptEmbeds,
            generator: torch.Generator,
            extra: dict,
    ):
        """Generate one sample (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # override this in child classes
        raise NotImplementedError(
            "generate_single_image must be implemented in child classes")

    def get_noise_prediction(
            self,
            latent_model_input: torch.Tensor,
            timestep: torch.Tensor,  # 0 to 1000 scale
            text_embeddings: PromptEmbeds,
            **kwargs
    ):
        """Predict noise for the given latents (not implemented yet).

        ``self`` was missing from the original signature, so on an instance
        call the instance was bound to ``latent_model_input`` and normal
        calls failed with a ``TypeError`` before reaching the body.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "get_noise_prediction must be implemented in child classes")

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` into ``PromptEmbeds`` (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "get_prompt_embeds must be implemented in child classes")
|