diff --git a/backend/memory_management.py b/backend/memory_management.py
index bad32392..a43268a2 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -405,7 +405,12 @@ class LoadedModel:
                     mem_counter += module_mem
                 else:
                     memory_in_swap += module_mem
+
+                    if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
+                        m.to(self.device)  # Quantize happens here
+                        m.to(self.model.offload_device)
+
                     if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device):
                         m._apply(lambda x: x.pin_memory())
 
             elif hasattr(m, "weight"):