mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 07:49:24 +00:00
Added initial support for Hidream E1 training
This commit is contained in:
@@ -52,6 +52,8 @@ BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full"
|
||||
|
||||
class HidreamModel(BaseModel):
|
||||
arch = "hidream"
|
||||
hidream_transformer_class = HiDreamImageTransformer2DModel
|
||||
hidream_pipeline_class = HiDreamImagePipeline
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -123,7 +125,7 @@ class HidreamModel(BaseModel):
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
|
||||
transformer = HiDreamImageTransformer2DModel.from_pretrained(
|
||||
transformer = self.hidream_transformer_class.from_pretrained(
|
||||
model_path,
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16
|
||||
@@ -216,7 +218,7 @@ class HidreamModel(BaseModel):
|
||||
flush()
|
||||
|
||||
if self.low_vram:
|
||||
self.print_and_status_update("Moving ecerything to device")
|
||||
self.print_and_status_update("Moving everything to device")
|
||||
# move it all back
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
vae.to(self.device_torch, dtype=dtype)
|
||||
@@ -233,7 +235,7 @@ class HidreamModel(BaseModel):
|
||||
text_encoder_4.eval()
|
||||
text_encoder_3.eval()
|
||||
|
||||
pipe = HiDreamImagePipeline(
|
||||
pipe = self.hidream_pipeline_class(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
Reference in New Issue
Block a user