diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f278fccac..b1d907ba4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -679,18 +679,19 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False): + def _load_list(self, prio_comfy_cast_weights=False, default_device=None): loading = [] for n, m in self.model.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) + default = False + params = { name: param for name, param in m.named_parameters(recurse=False) } for name, param in m.named_parameters(recurse=True): if name not in params: - skip = True # skip random weights in non leaf modules + default = True # default random weights in non leaf modules break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + if default and default_device is not None: + for param in params.values(): + param.data = param.data.to(device=default_device) + if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): @@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher): #with pin and unpin syncrhonization which can be expensive for small weights #with a high layer rate (e.g. autoregressive LLMs). #prioritize the non-comfy weights (note the order reverse). - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) for x in loading: @@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher): return 0 if vbar is None else vbar.free_memory(memory_to_free) def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) for x in loading: _, _, _, _, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index b6735d210..54f3d5595 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -355,13 +355,6 @@ class RMSNorm(nn.Module): -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): if not isinstance(theta, list): theta = [theta] @@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di else: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) - out.append((cos, sin)) + sin_split = sin.shape[-1] // 2 + out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) if len(out) == 1: return out[0] return out - def apply_rope(xq, xk, freqs_cis): org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] - q_embed = (xq * cos) + (rotate_half(xq) * sin) - k_embed = (xk * cos) + (rotate_half(xk) * sin) + nsin = freqs_cis[2] + + q_embed = (xq * cos) + q_split = q_embed.shape[-1] // 2 + q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin) + q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin) + + k_embed = (xk * cos) + k_split = k_embed.shape[-1] // 2 + k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin) + k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin) + return q_embed.to(org_dtype), k_embed.to(org_dtype)