diff --git a/toolkit/ema.py b/toolkit/ema.py index e3b3a7ea..b34554bb 100644 --- a/toolkit/ema.py +++ b/toolkit/ema.py @@ -137,7 +137,8 @@ class ExponentialMovingAverage: update_param = False if self.use_feedback: - param_float.add_(tmp) + # make feedback 10x decay + param_float.add_(tmp * 10) update_param = True if self.param_multiplier != 1.0: diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index e911484a..01d7f278 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -469,8 +469,9 @@ class DiffusionFeatureExtractor4(nn.Module): ) # embeds = id_embeds['hidden_states'][-2] # penultimate layer - embeds = id_embeds['pooler_output'] - return embeds + image_embeds = id_embeds['pooler_output'] + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + return image_embeds def forward( self,