Added initial support for Hidream E1 training

This commit is contained in:
Jaret Burkett
2025-07-27 15:12:56 -06:00
parent 3f518d9951
commit cefa2ca5fe
6 changed files with 1410 additions and 7 deletions

View File

@@ -52,6 +52,8 @@ BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full"
class HidreamModel(BaseModel):
arch = "hidream"
hidream_transformer_class = HiDreamImageTransformer2DModel
hidream_pipeline_class = HiDreamImagePipeline
def __init__(
self,
@@ -123,7 +125,7 @@ class HidreamModel(BaseModel):
self.print_and_status_update("Loading transformer")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
transformer = self.hidream_transformer_class.from_pretrained(
model_path,
subfolder="transformer",
torch_dtype=torch.bfloat16
@@ -216,7 +218,7 @@ class HidreamModel(BaseModel):
flush()
if self.low_vram:
self.print_and_status_update("Moving ecerything to device")
self.print_and_status_update("Moving everything to device")
# move it all back
transformer.to(self.device_torch, dtype=dtype)
vae.to(self.device_torch, dtype=dtype)
@@ -233,7 +235,7 @@ class HidreamModel(BaseModel):
text_encoder_4.eval()
text_encoder_3.eval()
pipe = HiDreamImagePipeline(
pipe = self.hidream_pipeline_class(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,