mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Cleanup and small bug fixes
This commit is contained in:
@@ -213,9 +213,12 @@ class LoRAModule(torch.nn.Module):
|
||||
device = state_dict['lora_up.weight'].device
|
||||
|
||||
# todo should we do this at fp32?
|
||||
if isinstance(self.normalize_scaler, torch.Tensor):
|
||||
scaler = self.normalize_scaler.clone().detach()
|
||||
else:
|
||||
scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)
|
||||
|
||||
total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
|
||||
.to(device, dtype=dtype)
|
||||
total_module_scale = scaler / target_normalize_scaler
|
||||
num_modules_layers = 2 # up and down
|
||||
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
|
||||
.to(device, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user