Transformers 4.56+ Compatibility & Force Offload Fix

This commit is contained in:
drbaph
2025-09-01 19:26:59 +01:00
parent fee5f78cc9
commit f565f123c6
3 changed files with 71 additions and 6 deletions

View File

@@ -300,7 +300,23 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
# Backwards compatible fix for _prepare_cache_for_generation method signature
# New transformers version expects 5 args, old version expects 6
import inspect
try:
sig = inspect.signature(self._prepare_cache_for_generation)
if len(sig.parameters) == 5:
# New transformers version (4.56+)
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
else:
# Old transformers version (pre-4.56)
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
except Exception as e:
# Fallback to try both versions
try:
self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)
except TypeError:
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
@@ -551,8 +567,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
negative_model_kwargs['attention_mask'][sample_idx, :] = 0
negative_model_kwargs['attention_mask'][sample_idx, -1] = 1
# update past key values
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
negative_model_kwargs['past_key_values'].value_cache)):
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
# Process each non-diffusion sample
for sample_idx in diffusion_start_indices.tolist():
# Shift cache for this sample
@@ -604,8 +620,8 @@ class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, Gener
negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0
# 2. Update past_key_values
for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache,
negative_model_kwargs['past_key_values'].value_cache)):
for layer_idx in range(len(negative_model_kwargs['past_key_values'])):
k_cache, v_cache = negative_model_kwargs['past_key_values'][layer_idx]
# Process each non-diffusion sample
for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()):
if start_idx + 1 < k_cache.shape[2] - 1: