Fix device issue

This commit is contained in:
huchenlei
2024-01-03 19:06:35 -05:00
parent 0d668aff1c
commit 5926b7e47b
9 changed files with 60 additions and 116 deletions

View File

@@ -8,7 +8,7 @@ class MeshGraphormerDetector:
self.pipeline = pipeline
@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device="cuda"):
def from_pretrained(cls, pretrained_model_or_path, filename=None, hrnet_filename=None, cache_dir=None, device=None):
filename = filename or "graphormer_hand_state_dict.bin"
hrnet_filename = hrnet_filename or "hrnetv2_w64_imagenet_pretrained.pth"
args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename, cache_dir)

View File

@@ -98,10 +98,9 @@ class MeshGraphormerMediapipe(Preprocessor):
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
device=args.device,
args.config_name if args.config_name else args.model_name_or_path
)
setattr(config, "device", args.device)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]