mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-27 10:13:56 +00:00
unify cast name
This commit is contained in:
@@ -337,9 +337,9 @@ class LoadedModel:
|
||||
real_async_memory = 0
|
||||
mem_counter = 0
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "ldm_patched_cast_weights"):
|
||||
m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
|
||||
m.ldm_patched_cast_weights = True
|
||||
if hasattr(m, "parameters_manual_cast"):
|
||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||
m.parameters_manual_cast = True
|
||||
module_mem = module_size(m)
|
||||
if mem_counter + module_mem < async_kept_memory:
|
||||
m.to(self.device)
|
||||
@@ -366,9 +366,9 @@ class LoadedModel:
|
||||
def model_unload(self, avoid_model_moving=False):
|
||||
if self.model_accelerated:
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "prev_ldm_patched_cast_weights"):
|
||||
m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
|
||||
del m.prev_ldm_patched_cast_weights
|
||||
if hasattr(m, "prev_parameters_manual_cast"):
|
||||
m.parameters_manual_cast = m.prev_parameters_manual_cast
|
||||
del m.prev_parameters_manual_cast
|
||||
|
||||
self.model_accelerated = False
|
||||
|
||||
|
||||
@@ -82,83 +82,83 @@ def cleanup_cache():
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear):
|
||||
ldm_patched_cast_weights = False
|
||||
parameters_manual_cast = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
def forward_parameters_manual_cast(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
return self.forward_ldm_patched_cast_weights(*args, **kwargs)
|
||||
if self.parameters_manual_cast:
|
||||
return self.forward_parameters_manual_cast(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
ldm_patched_cast_weights = False
|
||||
parameters_manual_cast = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
def forward_parameters_manual_cast(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
return self.forward_ldm_patched_cast_weights(*args, **kwargs)
|
||||
if self.parameters_manual_cast:
|
||||
return self.forward_parameters_manual_cast(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
ldm_patched_cast_weights = False
|
||||
parameters_manual_cast = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
def forward_parameters_manual_cast(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
return self.forward_ldm_patched_cast_weights(*args, **kwargs)
|
||||
if self.parameters_manual_cast:
|
||||
return self.forward_parameters_manual_cast(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
ldm_patched_cast_weights = False
|
||||
parameters_manual_cast = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
def forward_parameters_manual_cast(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
return self.forward_ldm_patched_cast_weights(*args, **kwargs)
|
||||
if self.parameters_manual_cast:
|
||||
return self.forward_parameters_manual_cast(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
ldm_patched_cast_weights = False
|
||||
parameters_manual_cast = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
def forward_parameters_manual_cast(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
return self.forward_ldm_patched_cast_weights(*args, **kwargs)
|
||||
if self.parameters_manual_cast:
|
||||
return self.forward_parameters_manual_cast(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@@ -174,16 +174,16 @@ class disable_weight_init:
|
||||
|
||||
class manual_cast(disable_weight_init):
|
||||
class Linear(disable_weight_init.Linear):
|
||||
ldm_patched_cast_weights = True
|
||||
parameters_manual_cast = True
|
||||
|
||||
class Conv2d(disable_weight_init.Conv2d):
|
||||
ldm_patched_cast_weights = True
|
||||
parameters_manual_cast = True
|
||||
|
||||
class Conv3d(disable_weight_init.Conv3d):
|
||||
ldm_patched_cast_weights = True
|
||||
parameters_manual_cast = True
|
||||
|
||||
class GroupNorm(disable_weight_init.GroupNorm):
|
||||
ldm_patched_cast_weights = True
|
||||
parameters_manual_cast = True
|
||||
|
||||
class LayerNorm(disable_weight_init.LayerNorm):
|
||||
ldm_patched_cast_weights = True
|
||||
parameters_manual_cast = True
|
||||
|
||||
Reference in New Issue
Block a user