mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Work on mean flow. Minor bug fixes. Omnigen improvements
This commit is contained in:
@@ -239,7 +239,10 @@ class OmniGen2Model(BaseModel):
|
||||
**kwargs
|
||||
):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
||||
try:
|
||||
timestep = timestep.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# optional_kwargs = {}
|
||||
# if 'ref_image_hidden_states' in set(inspect.signature(self.model.forward).parameters.keys()):
|
||||
|
||||
@@ -305,14 +305,14 @@ class OmniGen2Pipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
||||
# untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because Gemma can only handle sequences up to"
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
# if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
# removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
# logger.warning(
|
||||
# "The following part of your input was truncated because Gemma can only handle sequences up to"
|
||||
# f" {max_sequence_length} tokens: {removed_text}"
|
||||
# )
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = self.mllm(
|
||||
|
||||
Reference in New Issue
Block a user