Mirror of https://github.com/wildminder/ComfyUI-VibeVoice.git, synced 2026-05-01 04:01:37 +00:00.

Commit: "Transformers 4.56+ Compatibility & Force Offload Fix"
This commit is contained in:
@@ -300,7 +300,23 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
         )

         max_cache_length = generation_config.max_length - 1
-        self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
+
+        # Backwards compatible fix for _prepare_cache_for_generation method signature
+        # New transformers version expects 5 args, old version expects 6
+        import inspect
+        try:
+            sig = inspect.signature(self._prepare_cache_for_generation)
+            if len(sig.parameters) == 5:
+                # New transformers version (4.56+)
+                self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
+            else:
+                # Old transformers version (pre-4.56)
+                self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
+        except Exception as e:
+            # Fallback to try both versions
+            try:
+                self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
+            except TypeError:
+                self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
         model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
         for k, v in model_kwargs.items():
             if isinstance(v, torch.Tensor):
@@ -551,8 +567,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
             negative_model_kwargs['attention_mask'][sample_idx, :] = 0
             negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
         # update past key values
-        for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                           negative_model_kwargs['past_key_values'].value_cache)):
+        for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
+            k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
             # Process each non-diffusion sample
             for sample_idx in diffusion_start_indices.tolist():
                 # Shift cache for this sample
@@ -604,8 +620,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
                 negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0

             # 2. Update past_key_values
-            for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
-                                                               negative_model_kwargs['past_key_values'].value_cache)):
+            for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
+                k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
                 # Process each non-diffusion sample
                 for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
                     if start_idx + 1 < k_cache.shape[2] - 1:
||||
Reference in New Issue
Block a user