import torch


def _resolve_parent(obj, attr):
    """Walk all but the last component of dotted path *attr* from *obj*.

    Returns a ``(parent, leaf_name)`` tuple where ``parent`` is the object
    that owns the final attribute and ``leaf_name`` is that attribute's name.
    Raises ``AttributeError`` if an intermediate component does not exist.
    """
    *parents, leaf = attr.split(".")
    for name in parents:
        obj = getattr(obj, name)
    return obj, leaf


def set_attr(obj, attr, value):
    """Assign *value* at dotted path *attr*, wrapped as a frozen parameter.

    The value is wrapped in ``torch.nn.Parameter(..., requires_grad=False)``
    before being assigned, so the stored attribute is always a Parameter that
    does not track gradients.

    Args:
        obj: Root object the dotted path is resolved against.
        attr: Dotted attribute path, e.g. ``"block.linear.weight"``.
        value: Tensor to wrap and store.
    """
    set_attr_raw(obj, attr, torch.nn.Parameter(value, requires_grad=False))


def set_attr_raw(obj, attr, value):
    """Assign *value* at dotted path *attr* without any wrapping.

    Args:
        obj: Root object the dotted path is resolved against.
        attr: Dotted attribute path; every component except the last must
            already exist.
        value: Object to store as-is.
    """
    parent, leaf = _resolve_parent(obj, attr)
    setattr(parent, leaf, value)


def copy_to_param(obj, attr, value):
    """Copy *value* in place into the existing attribute at dotted path *attr*.

    Unlike ``set_attr`` this does not rebind the attribute: it writes through
    ``.data.copy_``, so the object already stored there is kept and only its
    contents change. The target must therefore already exist and expose
    ``.data`` (presumably a tensor or Parameter -- confirm against callers).

    Args:
        obj: Root object the dotted path is resolved against.
        attr: Dotted attribute path of the existing target.
        value: Source passed to ``copy_``.
    """
    parent, leaf = _resolve_parent(obj, attr)
    getattr(parent, leaf).data.copy_(value)


def get_attr(obj, attr):
    """Return the object found by following dotted path *attr* from *obj*.

    Raises ``AttributeError`` if any component of the path is missing.
    """
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj


def calculate_parameters(sd, prefix=""):
    """Return the total element count of entries in *sd* under *prefix*.

    Args:
        sd: Mapping from key strings to objects providing ``nelement()``.
        prefix: Only keys starting with this string are counted; the default
            empty prefix matches every key.

    Returns:
        Sum of ``nelement()`` over the matching values, or 0 if none match.
    """
    return sum(
        tensor.nelement()
        for key, tensor in sd.items()
        if key.startswith(prefix)
    )