Plumb the device argument through model construction; allow devices other than cuda

This commit is contained in:
huchenlei
2024-01-03 15:28:22 -05:00
parent 373886cde5
commit 0d668aff1c
7 changed files with 19 additions and 74 deletions

View File

@@ -97,8 +97,10 @@ class MeshGraphormerMediapipe(Preprocessor):
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
device=args.device,
)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
@@ -155,7 +157,7 @@ class MeshGraphormerMediapipe(Preprocessor):
#logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder, device=args.device)
if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
# for fine-tuning or resume training or inference, load weights from checkpoint