mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-30 11:11:13 +00:00
Fix device issue
This commit is contained in:
@@ -8,7 +8,7 @@ class MeshGraphormerDetector:
|
||||
self.pipeline = pipeline
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
|
||||
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None):
|
||||
filename = filename or "graphormer_hand_state_dict.bin"
|
||||
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
|
||||
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)
|
||||
|
||||
@@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor):
|
||||
for i in range(len(output_feat_dim)):
|
||||
config_class, model_class = BertConfig, Graphormer
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
device=args.device,
|
||||
args.config_name if args.config_name else args.model_name_or_path
|
||||
)
|
||||
|
||||
setattr(config, "device", args.device)
|
||||
config.output_attentions = False
|
||||
config.img_feature_dim = input_feat_dim[i]
|
||||
config.output_feature_dim = output_feat_dim[i]
|
||||
|
||||
Reference in New Issue
Block a user