WIP on wan

This commit is contained in:
Jaret Burkett
2025-03-01 16:12:52 -07:00
parent acc79956aa
commit f5e40dfa62
4 changed files with 45 additions and 6 deletions

View File

@@ -3,7 +3,38 @@ import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
from toolkit.paths import REPOS_ROOT
import sys
import os
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")
sys.path.append(WAN_ROOT)
if True:
from wan.text2video import WanT2V
from wan.distributed.fsdp import shard_model
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class Wan21(BaseModel):
def __init__(
@@ -21,9 +52,9 @@ class Wan21(BaseModel):
# these must be implemented in child classes
def load_model(self):
# override this in child classes
raise NotImplementedError(
"load_model must be implemented in child classes")
self.pipeline = Wan21(
)
def get_generation_pipeline(self):
# override this in child classes