mirror of
https://github.com/huchenlei/HandRefinerPortable.git
synced 2026-04-30 11:11:13 +00:00
Plumbing on device. Allow device != cuda
This commit is contained in:
@@ -97,8 +97,10 @@ class MeshGraphormerMediapipe(Preprocessor):
|
||||
# init three transformer-encoder blocks in a loop
|
||||
for i in range(len(output_feat_dim)):
|
||||
config_class, model_class = BertConfig, Graphormer
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name \
|
||||
else args.model_name_or_path)
|
||||
config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
config.output_attentions = False
|
||||
config.img_feature_dim = input_feat_dim[i]
|
||||
@@ -155,7 +157,7 @@ class MeshGraphormerMediapipe(Preprocessor):
|
||||
#logger.info('Backbone total parameters: {}'.format(backbone_total_params))
|
||||
|
||||
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder)
|
||||
_model = Graphormer_Network(args, config, backbone, trans_encoder, device=args.device)
|
||||
|
||||
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
|
||||
# for fine-tuning or resume training or inference, load weights from checkpoint
|
||||
|
||||
Reference in New Issue
Block a user