Fixed cleanup of emebddings.

This commit is contained in:
Jaret Burkett
2023-11-16 20:26:11 -07:00
parent ad50921c41
commit 6280284d8b
2 changed files with 27 additions and 3 deletions

View File

@@ -13,6 +13,21 @@ def flush(garbage_collect=True):
gc.collect()
def get_mean_std(tensor):
if len(tensor.shape) == 3:
tensor = tensor.unsqueeze(0)
elif len(tensor.shape) != 4:
raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
mean, variance = torch.mean(
tensor, dim=[2, 3], keepdim=True
), torch.var(
tensor, dim=[2, 3],
keepdim=True
)
std = torch.sqrt(variance + 1e-5)
return mean, std
def adain(content_features, style_features):
# Assumes that the content and style features are of shape (batch_size, channels, width, height)