diff --git a/demo.ipynb b/demo.ipynb index 0755de1..62e6c3c 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -106,7 +106,10 @@ "model = model.to(device)\n", "\n", "with torch.no_grad():\n", - " caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)\n", + " # beam search\n", + " caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5) \n", + " # nucleus sampling\n", + " # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) \n", " print('caption: '+caption[0])" ] },