Work on mean flow. Minor bug fixes. Omnigen improvements

commit 8d9c47316a
parent 84c6edca7e
Author: Jaret Burkett
Date: 2025-06-26 13:46:20 -06:00

4 changed files with 128 additions and 95 deletions


@@ -239,7 +239,10 @@ class OmniGen2Model(BaseModel):
         **kwargs
     ):
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        try:
+            timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+        except Exception as e:
+            pass
         # optional_kwargs = {}
         # if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):

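The guard added above makes the timestep broadcast best-effort: expand() only succeeds when the timestep is a scalar or already matches the batch size, so wrapping it keeps the forward pass alive when a caller hands in a pre-batched timestep. A minimal standalone sketch of the same pattern (broadcast_timestep and the tensor sizes below are illustrative, not from the repo):

import torch

def broadcast_timestep(timestep: torch.Tensor, latent_model_input: torch.Tensor) -> torch.Tensor:
    # Try to expand a scalar timestep across the batch dimension, matching the
    # latent dtype; if expand() fails (e.g. the timestep already carries an
    # incompatible batch dimension), keep it as-is, mirroring the try/except above.
    try:
        return timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
    except RuntimeError:
        return timestep.to(latent_model_input.dtype)

latents = torch.randn(4, 16, 32, 32)  # batch of 4
print(broadcast_timestep(torch.tensor(999.0), latents).shape)           # torch.Size([4])
print(broadcast_timestep(torch.tensor([999.0, 500.0]), latents).shape)  # falls back: torch.Size([2])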

@@ -305,14 +305,14 @@ class OmniGen2Pipeline(DiffusionPipeline):
         )
         text_input_ids = text_inputs.input_ids.to(device)
-        untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+        # untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because Gemma can only handle sequences up to"
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
+        # if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+        #     removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+        #     logger.warning(
+        #         "The following part of your input was truncated because Gemma can only handle sequences up to"
+        #         f" {max_sequence_length} tokens: {removed_text}"
+        #     )
         prompt_attention_mask = text_inputs.attention_mask.to(device)
         prompt_embeds = self.mllm(
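
For reference, the truncation warning this hunk comments out works by tokenizing the prompt twice, once truncated to max_sequence_length and once with padding="longest", then decoding whatever falls past the cutoff. A standalone sketch of that check (the gpt2 tokenizer and the 8-token limit are stand-ins, not the pipeline's Gemma processor):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
max_sequence_length = 8  # illustrative limit, not the pipeline's

prompt = ["a castle on a hill at sunset, oil painting, highly detailed, dramatic light"]
text_input_ids = tokenizer(
    prompt, padding="max_length", max_length=max_sequence_length,
    truncation=True, return_tensors="pt",
).input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# Same comparison as the commented-out block: if the untruncated encoding is
# longer and differs from the truncated one, decode and report the dropped tokens.
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
    removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
    print(f"Truncated because the model can only handle {max_sequence_length} tokens: {removed_text}")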