mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed cleanup of emebddings.
This commit is contained in:
@@ -13,6 +13,21 @@ def flush(garbage_collect=True):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def get_mean_std(tensor):
|
||||
if len(tensor.shape) == 3:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
elif len(tensor.shape) != 4:
|
||||
raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
|
||||
mean, variance = torch.mean(
|
||||
tensor, dim=[2, 3], keepdim=True
|
||||
), torch.var(
|
||||
tensor, dim=[2, 3],
|
||||
keepdim=True
|
||||
)
|
||||
std = torch.sqrt(variance + 1e-5)
|
||||
return mean, std
|
||||
|
||||
|
||||
def adain(content_features, style_features):
|
||||
# Assumes that the content and style features are of shape (batch_size, channels, width, height)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user