mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-26 07:13:57 +00:00
Refactor qwen5b model code to be qwen 5b specific
This commit is contained in:
@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
|
||||
from .f_light import FLiteModel
|
||||
from .omnigen2 import OmniGen2Model
|
||||
from .flux_kontext import FluxKontextModel
|
||||
from .wan22 import Wan22Model
|
||||
from .wan22 import Wan225bModel
|
||||
from .qwen_image import QwenImageModel
|
||||
|
||||
AI_TOOLKIT_MODELS = [
|
||||
@@ -14,6 +14,6 @@ AI_TOOLKIT_MODELS = [
|
||||
FLiteModel,
|
||||
OmniGen2Model,
|
||||
FluxKontextModel,
|
||||
Wan22Model,
|
||||
Wan225bModel,
|
||||
QwenImageModel,
|
||||
]
|
||||
|
||||
@@ -136,7 +136,7 @@ class QwenImageModel(BaseModel):
|
||||
if self.model_config.quantize_te:
|
||||
self.print_and_status_update("Quantizing Text Encoder")
|
||||
quantize(text_encoder, weights=get_qtype(
|
||||
self.model_config.qtype))
|
||||
self.model_config.qtype_te))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .wan22_model import Wan22Model
|
||||
from .wan22_5b_model import Wan225bModel
|
||||
@@ -80,7 +80,7 @@ def time_text_monkeypatch(
|
||||
|
||||
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
||||
|
||||
class Wan22Model(Wan21):
|
||||
class Wan225bModel(Wan21):
|
||||
arch = "wan22_5b"
|
||||
_wan_generation_scheduler_config = scheduler_configUniPC
|
||||
_wan_expand_timesteps = True
|
||||
Reference in New Issue
Block a user