unify cast name

layerdiffusion
2024-07-30 08:42:51 -06:00
parent 9a48c9eff3
commit abd4d4d83d
2 changed files with 31 additions and 31 deletions


@@ -337,9 +337,9 @@ class LoadedModel:
         real_async_memory = 0
         mem_counter = 0
         for m in self.real_model.modules():
-            if hasattr(m, "ldm_patched_cast_weights"):
-                m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
-                m.ldm_patched_cast_weights = True
+            if hasattr(m, "parameters_manual_cast"):
+                m.prev_parameters_manual_cast = m.parameters_manual_cast
+                m.parameters_manual_cast = True
             module_mem = module_size(m)
             if mem_counter + module_mem < async_kept_memory:
                 m.to(self.device)
@@ -366,9 +366,9 @@ class LoadedModel:
     def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
-                if hasattr(m, "prev_ldm_patched_cast_weights"):
-                    m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
-                    del m.prev_ldm_patched_cast_weights
+                if hasattr(m, "prev_parameters_manual_cast"):
+                    m.parameters_manual_cast = m.prev_parameters_manual_cast
+                    del m.prev_parameters_manual_cast
             self.model_accelerated = False
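
The two hunks above are a matched pair: on load, every module that supports the flag has its current value stashed and the flag forced on; on unload, the stashed value is restored and the temporary attribute dropped. A minimal sketch of that round trip, assuming a toy module in place of the real model (FakeLinear is hypothetical; only the flag names come from this diff):

import torch

# Hypothetical stand-in for an op that supports manual casting.
class FakeLinear(torch.nn.Linear):
    parameters_manual_cast = False

model = torch.nn.Sequential(FakeLinear(4, 4), torch.nn.ReLU())

# Load: snapshot each module's flag, then force manual casting on.
for m in model.modules():
    if hasattr(m, "parameters_manual_cast"):
        m.prev_parameters_manual_cast = m.parameters_manual_cast
        m.parameters_manual_cast = True

# Unload: restore the snapshot and drop the temporary attribute.
for m in model.modules():
    if hasattr(m, "prev_parameters_manual_cast"):
        m.parameters_manual_cast = m.prev_parameters_manual_cast
        del m.prev_parameters_manual_cast

Stashing the previous value on the module itself keeps load and unload symmetric without any external bookkeeping table.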


@@ -82,83 +82,83 @@ def cleanup_cache():

 class disable_weight_init:
     class Linear(torch.nn.Linear):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class Conv2d(torch.nn.Conv2d):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class Conv3d(torch.nn.Conv3d):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class GroupNorm(torch.nn.GroupNorm):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

     class LayerNorm(torch.nn.LayerNorm):
-        ldm_patched_cast_weights = False
+        parameters_manual_cast = False

         def reset_parameters(self):
             return None

-        def forward_ldm_patched_cast_weights(self, input):
+        def forward_parameters_manual_cast(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
             with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.ldm_patched_cast_weights:
-                return self.forward_ldm_patched_cast_weights(*args, **kwargs)
+            if self.parameters_manual_cast:
+                return self.forward_parameters_manual_cast(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
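
Every op class in the hunk above follows the same dispatch pattern: a class-level flag selects between a manual-cast forward and the stock PyTorch forward. A minimal sketch of that pattern, with simplified stand-ins for cast_bias_weight and main_stream_worker (the real helpers also coordinate CUDA streams, which is what signal is for):

import contextlib
import torch

def cast_bias_weight(module, input):
    # Simplified: cast parameters to the input's dtype/device on the fly.
    weight = module.weight.to(dtype=input.dtype, device=input.device)
    bias = None if module.bias is None else module.bias.to(dtype=input.dtype, device=input.device)
    return weight, bias, None

@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
    # Simplified: the real helper synchronizes with a CUDA side-stream.
    yield

class Linear(torch.nn.Linear):
    parameters_manual_cast = False

    def forward_parameters_manual_cast(self, input):
        weight, bias, signal = cast_bias_weight(self, input)
        with main_stream_worker(weight, bias, signal):
            return torch.nn.functional.linear(input, weight, bias)

    def forward(self, *args, **kwargs):
        if self.parameters_manual_cast:
            return self.forward_parameters_manual_cast(*args, **kwargs)
        return super().forward(*args, **kwargs)

layer = Linear(4, 4)                                  # fp32 weights
layer.parameters_manual_cast = True
out = layer(torch.randn(2, 4, dtype=torch.bfloat16))  # bf16 input
print(out.dtype)  # torch.bfloat16: weights were cast to match the input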
@@ -174,16 +174,16 @@ class disable_weight_init:

 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class Conv2d(disable_weight_init.Conv2d):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class Conv3d(disable_weight_init.Conv3d):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class GroupNorm(disable_weight_init.GroupNorm):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True

     class LayerNorm(disable_weight_init.LayerNorm):
-        ldm_patched_cast_weights = True
+        parameters_manual_cast = True
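
After the rename, both namespaces expose one switch: disable_weight_init leaves parameters_manual_cast off, manual_cast turns it on, and everything else is inherited. Model-building code can therefore pick the casting mode by passing one namespace or the other; a sketch under the assumption that the file is importable as ldm_patched.modules.ops (the commit view does not show file paths, and build_block is illustrative):

import torch
from ldm_patched.modules import ops  # assumed import path

def build_block(cin, cout, operations=ops.disable_weight_init):
    # The `operations` namespace supplies the layer classes, so the same
    # architecture code runs with or without manual casting.
    return torch.nn.Sequential(
        operations.Conv2d(cin, cout, kernel_size=3, padding=1),
        operations.GroupNorm(1, cout),
    )

block = build_block(8, 16, operations=ops.manual_cast)
print(all(m.parameters_manual_cast
          for m in block.modules()
          if hasattr(m, "parameters_manual_cast")))  # True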