mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP on wan
This commit is contained in:
@@ -3,7 +3,38 @@ import torch
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
import os
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
WAN_ROOT = os.path.join(REPOS_ROOT, "wan21")
|
||||
sys.path.append(WAN_ROOT)
|
||||
|
||||
if True:
|
||||
from wan.text2video import WanT2V
|
||||
from wan.distributed.fsdp import shard_model
|
||||
from wan.modules.model import WanModel
|
||||
from wan.modules.t5 import T5EncoderModel
|
||||
from wan.modules.vae import WanVAE
|
||||
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps)
|
||||
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
|
||||
class Wan21(BaseModel):
|
||||
def __init__(
|
||||
@@ -21,9 +52,9 @@ class Wan21(BaseModel):
|
||||
# these must be implemented in child classes
|
||||
|
||||
def load_model(self):
|
||||
# override this in child classes
|
||||
raise NotImplementedError(
|
||||
"load_model must be implemented in child classes")
|
||||
self.pipeline = Wan21(
|
||||
|
||||
)
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
# override this in child classes
|
||||
|
||||
Reference in New Issue
Block a user