Update forge_reference.py

This commit is contained in:
lllyasviel
2024-01-30 20:42:54 -08:00
parent c6a344aed3
commit cf6c5768ae

View File

@@ -83,6 +83,8 @@ class PreprocessorReference(Preprocessor):
if not (sigma_min <= sigma <= sigma_max):
return h
C = int(h.shape[1])
if self.is_recording_style:
self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
return h
@@ -121,6 +123,8 @@ class PreprocessorReference(Preprocessor):
location = (transformer_options['block'][0], transformer_options['block'][1],
transformer_options['block_index'])
C = int(q.shape[2])
if self.is_recording_style:
self.recorded_attn1[location] = (k, v)
return sdp(q, k, v, transformer_options)