Update forge_reference.py

This commit is contained in:
lllyasviel
2024-01-30 20:45:18 -08:00
parent cf6c5768ae
commit 035ad4836a

View File

@@ -83,7 +83,11 @@ class PreprocessorReference(Preprocessor):
if not (sigma_min <= sigma <= sigma_max):
return h
C = int(h.shape[1])
channel = int(h.shape[1])
minimal_channel = 1280 - 640 * weight
if channel < minimal_channel:
return h
if self.is_recording_style:
self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
@@ -123,7 +127,11 @@ class PreprocessorReference(Preprocessor):
location = (transformer_options['block'][0], transformer_options['block'][1],
transformer_options['block_index'])
C = int(q.shape[2])
channel = int(q.shape[2])
minimal_channel = 1280 - 1280 * weight
if channel < minimal_channel:
return sdp(q, k, v, transformer_options)
if self.is_recording_style:
self.recorded_attn1[location] = (k, v)