From abd4d4d83dbd91f3210f2383844e5b86b8c7535f Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Tue, 30 Jul 2024 08:42:51 -0600
Subject: [PATCH] unify cast name

---
 ldm_patched/modules/model_management.py | 12 +++---
 ldm_patched/modules/ops.py              | 50 ++++++++++++-------------
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 0fb26eb9..39e594c0 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -337,9 +337,9 @@ class LoadedModel:
         real_async_memory = 0
         mem_counter = 0
         for m in self.real_model.modules():
-            if hasattr(m, "ldm_patched_cast_weights"):
-                m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
-                m.ldm_patched_cast_weights = True
+            if hasattr(m, "parameters_manual_cast"):
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
             module_mem = module_size(m)
             if mem_counter + module_mem < async_kept_memory:
                 m.to(self.device)
@@ -366,9 +366,9 @@ class LoadedModel:
     def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
-                if hasattr(m, "prev_ldm_patched_cast_weights"):
-                    m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
-                    del m.prev_ldm_patched_cast_weights
+                if hasattr(m, "prev_parameters_manual_cast"):
+                    m.parameters_manual_cast = m.prev_parameters_manual_cast
+                    del m.prev_parameters_manual_cast

             self.model_accelerated = False

diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index beb8f265..286f83a5 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -82,83 +82,83 @@ def cleanup_cache():

 class disable_weight_init:
     class Linear(torch.nn.Linear):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class Conv2d(torch.nn.Conv2d):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class Conv3d(torch.nn.Conv3d):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class GroupNorm(torch.nn.GroupNorm):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class LayerNorm(torch.nn.LayerNorm):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

@@ -174,16 +174,16 @@ class disable_weight_init:

 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class Conv2d(disable_weight_init.Conv2d):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class Conv3d(disable_weight_init.Conv3d):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class GroupNorm(disable_weight_init.GroupNorm):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class LayerNorm(disable_weight_init.LayerNorm):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True
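
Note on the renamed flag: when parameters_manual_cast is True, forward() does
not trust the dtype/device the parameters are stored in; it casts weight and
bias to match the input on every call, which is what lets manual_cast modules
keep parameters offloaded or stored at a different precision. A minimal
standalone sketch of that pattern, assuming plain PyTorch only (simplified:
the repository's cast_bias_weight also returns a stream signal consumed by
main_stream_worker, omitted here; ManualCastLinear is an illustrative name,
not a class from this patch):

    import torch

    class ManualCastLinear(torch.nn.Linear):
        # Mirrors the patched classes: False means "behave like a normal Linear".
        parameters_manual_cast = False

        def forward(self, x):
            if not self.parameters_manual_cast:
                return super().forward(x)
            # Cast parameters to the input's dtype/device on every call, so the
            # module can keep its weights stored elsewhere (e.g. fp16 on CPU)
            # while computing wherever, and however, the input lives.
            weight = self.weight.to(dtype=x.dtype, device=x.device)
            bias = None
            if self.bias is not None:
                bias = self.bias.to(dtype=x.dtype, device=x.device)
            return torch.nn.functional.linear(x, weight, bias)

    # Usage: weights stay stored in fp16, compute runs in fp32.
    layer = ManualCastLinear(4, 4).half()
    layer.parameters_manual_cast = True
    out = layer(torch.randn(2, 4, dtype=torch.float32))  # fp32 output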