Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git (synced 2026-04-29 19:01:20 +00:00)
Continued monkey patching fairseq for RVC python 3.11
@@ -26,20 +26,20 @@ sys.modules["fairseq.metrics"] = metrics
 sys.modules["fairseq.progress_bar"] = progress_bar
 
 # initialize hydra
-from fairseq.dataclass.initialize import hydra_init
+#from fairseq.dataclass.initialize import hydra_init
 
 #hydra_init()
 
-import fairseq.criterions  # noqa
-import fairseq.distributed  # noqa
-import fairseq.models  # noqa
-import fairseq.modules  # noqa
-import fairseq.optim  # noqa
-import fairseq.optim.lr_scheduler  # noqa
-import fairseq.pdb  # noqa
-import fairseq.scoring  # noqa
-import fairseq.tasks  # noqa
-import fairseq.token_generation_constraints  # noqa
+#import fairseq.criterions  # noqa
+#import fairseq.distributed  # noqa
+#import fairseq.models  # noqa
+#import fairseq.modules  # noqa
+#import fairseq.optim  # noqa
+#import fairseq.optim.lr_scheduler  # noqa
+#import fairseq.pdb  # noqa
+#import fairseq.scoring  # noqa
+#import fairseq.tasks  # noqa
+#import fairseq.token_generation_constraints  # noqa
 
-import fairseq.benchmark  # noqa
-import fairseq.model_parallel  # noqa
+#import fairseq.benchmark  # noqa
+#import fairseq.model_parallel  # noqa
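For context: the patch above works because Python consults sys.modules before searching for a module on disk, so pre-registering a replacement object under a fairseq submodule name short-circuits the real (Python 3.11-incompatible) import. A minimal sketch of the technique; the stub attribute here is illustrative, not taken from the patch:

import sys
import types

# Build a stand-in module and register it up front. Any later
# "import fairseq.metrics" now resolves to this stub instead of
# the real implementation, so its module-level code never runs.
metrics = types.ModuleType("fairseq.metrics")
metrics.reset = lambda: None  # hypothetical attribute callers might expect
sys.modules["fairseq.metrics"] = metrics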
@@ -1,7 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# import models/tasks to register them
from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
@@ -1,172 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import random

import torch
from torch.utils import benchmark

from fairseq.modules.multihead_attention import MultiheadAttention

BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]


def _reset_seeds():
    torch.manual_seed(0)
    random.seed(0)


def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)


def benchmark_multihead_attention(
    label="",
    attn_dtype=torch.uint8,
    key_padding_dtype=torch.uint8,
    add_bias_kv=False,
    add_zero_attn=False,
    static_kv=False,
    batch_size=20,
    embedding=EMB,
    seq_len=SEQ,
    num_heads=HEADS,
):

    results = []
    # device = torch.device("cuda")

    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
    key_padding_mask = _get_mask(
        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
    )

    q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
    v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)

    _reset_seeds()

    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    )

    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )

    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = original_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        output, _ = xformers_mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        loss = torch.norm(output)
        loss.backward()

    fns = [
        original_bench_fw,
        xformers_bench_fw,
        original_bench_fw_bw,
        xformers_bench_fw_bw,
    ]

    for fn in fns:
        results.append(
            benchmark.Timer(
                stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
                globals={
                    "q": q,
                    "k": k,
                    "v": v,
                    "key_padding_mask": key_padding_mask,
                    "attn_mask": attn_mask,
                    "static_kv": static_kv,
                    "fn": fn,
                },
                label="multihead fw + bw",
                sub_label=f"{fn.__name__}",
                description=label,
            ).blocked_autorange(min_run_time=1)
        )

    compare = benchmark.Compare(results)
    compare.print()


def run_benchmarks():
    for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
        ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
    ):
        label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
        add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
        benchmark_multihead_attention(
            label=label,
            attn_dtype=attn_dtype,
            key_padding_dtype=key_padding_dtype,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )


run_benchmarks()
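For reference, the deleted benchmark drives torch.utils.benchmark directly. A minimal standalone sketch of the same measurement pattern (the timed statement here is illustrative, not from the file):

import torch
from torch.utils import benchmark

# Time a statement with automatically chosen iteration counts, then
# print a comparison table - the same Timer/Compare pattern used above.
t = benchmark.Timer(
    stmt="x @ x",
    globals={"x": torch.rand(256, 256)},
    label="matmul",
    sub_label="256x256",
    description="baseline",
)
benchmark.Compare([t.blocked_autorange(min_run_time=1)]).print()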
@@ -1,36 +0,0 @@
import numpy as np
from fairseq.data import FairseqDataset


class DummyDataset(FairseqDataset):
    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
@@ -1,83 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II


logger = logging.getLogger(__name__)


@dataclass
class DummyLMConfig(FairseqDataclass):
    dict_size: int = 49996
    dataset_size: int = 100000
    tokens_per_sample: int = field(
        default=512, metadata={"help": "max sequence length"}
    )
    add_bos_token: bool = False
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    max_target_positions: int = II("task.tokens_per_sample")


@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
    def __init__(self, cfg: DummyLMConfig):
        super().__init__(cfg)

        # load dictionary
        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
        logger.info("dictionary: {} types".format(len(self.dictionary)))

        seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1

        self.dummy_src = seq[:-1]
        self.dummy_tgt = seq[1:]

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
@@ -1,94 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from omegaconf import II

from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

logger = logging.getLogger(__name__)


@dataclass
class DummyMaskedLMConfig(FairseqDataclass):
    dict_size: int = 49996
    dataset_size: int = 100000
    tokens_per_sample: int = field(
        default=512,
        metadata={
            "help": "max number of total tokens over all"
            " segments per sample for BERT dataset"
        },
    )
    batch_size: Optional[int] = II("dataset.batch_size")
    max_tokens: Optional[int] = II("dataset.max_tokens")
    max_target_positions: int = II("task.tokens_per_sample")


@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
class DummyMaskedLMTask(FairseqTask):
    def __init__(self, cfg: DummyMaskedLMConfig):
        super().__init__(cfg)

        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: {} types".format(len(self.dictionary)))
        # add mask token
        self.mask_idx = self.dictionary.add_symbol("<mask>")
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        mask_idx = 0
        pad_idx = 1
        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
        src = seq.clone()
        src[mask] = mask_idx
        tgt = torch.full_like(seq, pad_idx)
        tgt[mask] = seq[mask]

        self.dummy_src = src
        self.dummy_tgt = tgt

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                    ),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
@@ -1,96 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        parser.add_argument("--num-layers", type=int, default=24)
        parser.add_argument("--embed-dim", type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        encoder = DummyEncoder(
            num_embed=len(task.target_dictionary),
            embed_dim=args.embed_dim,
            num_layers=args.num_layers,
        )
        return cls(args, encoder)

    def forward(self, src_tokens, masked_tokens=None, **kwargs):
        return self.decoder(src_tokens, masked_tokens=masked_tokens)


class DummyEncoder(FairseqDecoder):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
                    nn.Linear(embed_dim, embed_dim),  # output projection
                    nn.Dropout(),
                )
                for i in range(num_layers)
            ]
        )
        self.layers_b = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
                    nn.ReLU(),
                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
                    nn.Dropout(0.1),
                )
                for i in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens, masked_tokens=None):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)
            x = x + layer_b(x)
        x = self.out_proj(x)
        if masked_tokens is not None:
            x = x[masked_tokens]
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)


@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
    pass
@@ -1,119 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("--dict-size", default=49996, type=int)
        parser.add_argument("--dataset-size", default=100000, type=int)
        parser.add_argument("--src-len", default=30, type=int)
        parser.add_argument("--tgt-len", default=30, type=int)

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
        self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: {} types".format(len(dictionary)))

        args.max_source_positions = args.src_len + dictionary.pad() + 2
        args.max_target_positions = args.tgt_len + dictionary.pad() + 2

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        item_size = max(self.args.src_len, self.args.tgt_len)
        if self.args.batch_size is not None:
            bsz = self.args.batch_size
        else:
            bsz = max(1, self.args.max_tokens // item_size)
        tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths": torch.full(
                        (bsz,), self.args.src_len, dtype=torch.long
                    ),
                    "prev_output_tokens": tgt.clone(),
                },
                "target": tgt,
                "nsentences": bsz,
                "ntokens": bsz * self.args.tgt_len,
            },
            num_items=self.args.dataset_size,
            item_size=item_size,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


class DummyDataset(FairseqDataset):
    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.num_items

    def collater(self, samples):
        return self.batch

    @property
    def sizes(self):
        return np.array([self.item_size] * self.num_items)

    def num_tokens(self, index):
        return self.item_size

    def size(self, index):
        return self.item_size

    def ordered_indices(self):
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        return False
@@ -1,55 +0,0 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(
      tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &ngram_repeat_block_forward,
      "No Repeat Ngram Block forward (CUDA)");
}
@@ -1,82 +0,0 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(
    long* __restrict__ tokens,
    float* __restrict__ lprobs,
    int max_predict_len,
    int vocab_size,
    int no_repeat_ngram_size) {
  auto row = blockIdx.x;
  auto col = threadIdx.x;
  auto start = row * (max_predict_len) + col;
  // Each thread compares the ngram starting at its thread index with the
  // final ngram, which starts at step - no_repeat_ngram_size + 2
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  extern __shared__ long tokens_shm[];
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned == true) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
    const torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0)
    return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launching N blocks where N is the number of samples in a batch (beams*bsz)
  // Launching T threads where T is the number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens, since
  // each token will be accessed N times to compare with the current ngram,
  // where N is the ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
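For reference, a plain-Python restatement of the blocking rule this kernel parallelizes (ban any token that would complete an n-gram already present in the hypothesis); this is an illustrative sketch for one hypothesis, not the deleted CUDA code:

import math

def ban_repeated_ngrams(tokens, lprobs, no_repeat_ngram_size):
    # tokens: generated token ids for one hypothesis
    # lprobs: mutable list of log-probs over the vocabulary
    n = no_repeat_ngram_size
    if len(tokens) < n - 1:
        return
    prefix = tuple(tokens[len(tokens) - (n - 1):])  # last n-1 tokens
    # Any earlier occurrence of this prefix bans its continuation token.
    for i in range(len(tokens) - n + 1):
        if tuple(tokens[i:i + n - 1]) == prefix:
            lprobs[tokens[i + n - 1]] = -math.inf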
@@ -1,109 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from: https://github.com/bkj/auction-lap Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
  int max_iterations = 100;
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids =
      worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th
    // job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));

    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one
      // else bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0);

    if (have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);

    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next
      // time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }

    counter += 1;
  }

  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
      .reshape(-1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
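For intuition, a minimal Python sketch of the classic epsilon-auction idea the file generalizes to k jobs per worker (this is the textbook single-job variant on a dense square score matrix, illustrative only, not the deleted implementation):

def auction_assignment(score, eps=1e-3):
    # score[w][j]: value of job j to worker w; returns the job index per worker.
    n = len(score)
    price = [0.0] * n
    owner = [None] * n       # owner[j] = worker currently holding job j
    assigned = [None] * n    # assigned[w] = job currently held by worker w
    unassigned = list(range(n))
    while unassigned:
        w = unassigned.pop()
        # Worker bids on its best job at current prices; the bid increment is
        # the margin over the second-best job, plus eps to guarantee progress.
        values = [score[w][j] - price[j] for j in range(n)]
        best = max(range(n), key=lambda j: values[j])
        second = max(v for j, v in enumerate(values) if j != best) if n > 1 else 0.0
        price[best] += values[best] - second + eps
        if owner[best] is not None:  # evict the previous owner of the job
            assigned[owner[best]] = None
            unassigned.append(owner[best])
        owner[best] = w
        assigned[w] = best
    return assigned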
@@ -1,157 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <array>
#include <cstdio>
#include <cstring>
#include <map>

// NOLINTNEXTLINE
typedef struct {
  size_t reflen;
  size_t predlen;
  size_t match1;
  size_t count1;
  size_t match2;
  size_t count2;
  size_t match3;
  size_t count3;
  size_t match4;
  size_t count4;
} bleu_stat;

// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t start = 0;
  while (start < *len) {
    if (*(*sent + start) != pad) {
      break;
    }
    start++;
  }
  *sent += start;
  *len -= start;
}

// right trim (remove eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  size_t end = *len - 1;
  while (end > 0) {
    if (*(*sent + end) != eos && *(*sent + end) != pad) {
      break;
    }
    end--;
  }
  *len = end + 1;
}

// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}

size_t bleu_hash(int len, int* data) {
  size_t h = 14695981039346656037ul;
  size_t prime = 0x100000001b3;
  char* b = (char*)data;
  size_t blen = sizeof(int) * len;

  while (blen-- > 0) {
    h ^= *b++;
    h *= prime;
  }

  return h;
}

void bleu_addngram(
    size_t* ntotal,
    size_t* nmatch,
    size_t n,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred) {
  if (predlen < n) {
    return;
  }

  predlen = predlen - n + 1;
  (*ntotal) += predlen;

  if (reflen < n) {
    return;
  }

  reflen = reflen - n + 1;

  std::map<size_t, size_t> count;
  while (predlen > 0) {
    size_t w = bleu_hash(n, pred++);
    count[w]++;
    predlen--;
  }

  while (reflen > 0) {
    size_t w = bleu_hash(n, ref++);
    if (count[w] > 0) {
      (*nmatch)++;
      count[w] -= 1;
    }
    reflen--;
  }
}

extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_add(
    bleu_stat* stat,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred,
    int pad,
    int eos) {

  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;

  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}
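Two details worth noting: bleu_hash is the 64-bit FNV-1a hash applied to the raw bytes of each n-gram, and the four match/count pairs are exactly the sufficient statistics for corpus BLEU. A hedged Python restatement of both; the BLEU formula is the standard one (assuming nonzero counts), not code from the deleted file:

import math

FNV_OFFSET = 14695981039346656037
FNV_PRIME = 0x100000001B3

def fnv1a_64(data: bytes) -> int:
    # Same byte-wise FNV-1a loop as bleu_hash, reduced modulo 2**64.
    h = FNV_OFFSET
    for b in data:
        h = ((h ^ b) * FNV_PRIME) % (1 << 64)
    return h

def bleu_from_stats(matches, counts, reflen, predlen):
    # Standard corpus BLEU: brevity penalty times the geometric mean of
    # the 1..4-gram precisions accumulated in bleu_stat.
    log_prec = sum(math.log(m / c) for m, c in zip(matches, counts)) / 4
    bp = min(1.0, math.exp(1 - reflen / predlen))
    return bp * math.exp(log_prec)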
@@ -1,33 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <Python.h>

static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT

static struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "libbleu", /* name of module */
    // NOLINTNEXTLINE
    NULL, /* module documentation, may be NULL */
    -1, /* size of per-interpreter state of the module,
           or -1 if the module keeps state in global variables. */
    method_def}; // NOLINT

#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject* m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
@@ -1,231 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

using namespace ::std;

vector<vector<uint32_t>> edit_distance2_with_dp(
    vector<uint32_t>& x,
    vector<uint32_t>& y) {
  uint32_t lx = x.size();
  uint32_t ly = y.size();
  vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
  for (uint32_t i = 0; i < lx + 1; i++) {
    d[i][0] = i;
  }
  for (uint32_t j = 0; j < ly + 1; j++) {
    d[0][j] = j;
  }
  for (uint32_t i = 1; i < lx + 1; i++) {
    for (uint32_t j = 1; j < ly + 1; j++) {
      d[i][j] =
          min(min(d[i - 1][j], d[i][j - 1]) + 1,
              d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
    }
  }
  return d;
}

vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */

  if (x.size() == 0) {
    edit_seqs.at(0) = y;
    return edit_seqs;
  }

  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1); // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2); // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3); // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }

  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }

    prev_op = op;
  }

  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}

vector<uint32_t> compute_ed2(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys) {
  vector<uint32_t> distances(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
  }
  return distances;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) =
        edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
  }
  return seq;
}

vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
    vector<vector<uint32_t>>& xs,
    vector<vector<uint32_t>>& ys,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<vector<vector<uint32_t>>> seq(xs.size());
  for (uint32_t i = 0; i < xs.size(); i++) {
    vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
    seq.at(i) = edit_distance2_backtracking_with_delete(
        d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
  }
  return seq;
}

PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
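For reference, the DP above is Levenshtein distance with substitutions costing 2, so a substitution is never cheaper than a delete plus an insert. A short Python restatement of just the distance table:

def edit_distance2(x, y):
    # d[i][j]: cost of editing x[:i] into y[:j]; insert/delete cost 1,
    # substitution costs 2, matching the C++ edit_distance2_with_dp.
    d = [[0] * (len(y) + 1) for _ in range(len(x) + 1)]
    for i in range(len(x) + 1):
        d[i][0] = i
    for j in range(len(y) + 1):
        d[0][j] = j
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            d[i][j] = min(
                d[i - 1][j] + 1,  # delete x[i-1]
                d[i][j - 1] + 1,  # insert y[j-1]
                d[i - 1][j - 1] + 2 * (x[i - 1] != y[j - 1]),  # keep/substitute
            )
    return d[len(x)][len(y)]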
@@ -1,67 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
This code is partially adapted from
https://github.com/1ytic/pytorch-edit-distance
*/

#include <torch/types.h>
#include "edit_dist.h"

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}

torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def(
      "generate_deletion_labels",
      &GenerateDeletionLabel,
      "Generate Deletion Label");
  m.def(
      "generate_insertion_labels",
      &GenerateInsertionLabel,
      "Generate Insertion Label");
}
@@ -1,344 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "edit_dist.h"

#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair

template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;

  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }

  int k = 0;
  for (int i = 0; i < operation_size; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 1) {
      continue;
    } else {
      labels[offset_label + k] = 3 - operations[offset + i];
      k++;
    }
  }
}

template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;

  int k = 0;
  int u = 0;
  int m = 0;

  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }

  for (int i = 0; i < operation_size - 1; i++) {
    if (operations[offset + i] == 0) {
      break;
    } else if (operations[offset + i] == 2) {
      continue;
    } else if (operations[offset + i] == 1) {
      masks[offset_label + m] = 1;
      u++;
      m++;
    } else {
      labels[offset_label + k] = u;
      masks[offset_label + m] = 0;
      k++;
      m++;
      u = 0;
    }
  }
}

template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int d = index * (source_size + 1) * (target_size + 1);
  const int t = target_size + 1;

  auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  extern __shared__ short errors[];
  auto errors_curr = errors;

  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int t = target_size + 1;

  auto err_idx = [t](int i, int j) { return i * t + j; };
  auto opt_idx = [offset](int k) { return offset + k; };

  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;

  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }

  // back-tracing
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;

  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }

  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }

    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--; // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--; // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--; // do nothing
    }
  }

  // moving to the left
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0; // padding
    }
  }
}

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
                          generate_deletion_label_kernel<scalar_t>
                              <<<batch_size, 1, 0, stream>>>(
                                  source.data_ptr<scalar_t>(),
                                  source.size(1),
                                  operations.size(1),
                                  operations.data_ptr<int>(),
                                  labels.data_ptr<int>());
                        }));

  return labels;
}

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(
      target.scalar_type(), "generate_insertion_labels", ([&] {
        generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
            target.data_ptr<scalar_t>(),
            target.size(1),
            operations.size(1),
            operations.data_ptr<int>(),
            labels.data_ptr<int>(),
            masks.data_ptr<int>());
      }));

  return std::make_pair(labels, masks);
}

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  const auto batch_size = source.size(0);
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
                            levenshtein_distance_kernel<scalar_t>
                                <<<batch_size, 1, 0, stream>>>(
                                    source.data_ptr<scalar_t>(),
                                    target.data_ptr<scalar_t>(),
                                    source_length.data_ptr<int>(),
                                    target_length.data_ptr<int>(),
                                    source.size(1),
                                    target.size(1),
                                    operations.data_ptr<int>(),
                                    distances.data_ptr<int>());
                          }));
  } else {
    AT_DISPATCH_ALL_TYPES(
        source.scalar_type(), "faster_levenshtein_distance", ([&] {
          faster_levenshtein_distance_kernel<scalar_t>
              <<<batch_size, 1, shared_size, stream>>>(
                  source.data_ptr<scalar_t>(),
                  target.data_ptr<scalar_t>(),
                  source_length.data_ptr<int>(),
                  target_length.data_ptr<int>(),
                  source.size(1),
                  target.size(1),
                  operations.data_ptr<int>());
        }));
  }

  return operations;
}
@@ -1,25 +0,0 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <torch/extension.h>

torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length);

torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);

std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);
@@ -1,4 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,19 +0,0 @@
# @package _group_

hydra:
  run:
    dir: .

defaults:
  - _self_
  - task: null
  - model: null
  - criterion: cross_entropy
  - optimizer: null
  - lr_scheduler: fixed
  - bpe: null
  - tokenizer: null
  - scoring: null
  - generation: null
  - common_eval: null
  - eval_lm: null
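For context, this is a standard Hydra defaults list: each entry selects a value from a config group, and Hydra composes the selections into a single config at launch, with later selections overriding earlier ones. A hedged sketch of what that composition boils down to, using OmegaConf directly (the keys here are illustrative):

from omegaconf import OmegaConf

# Composition (normally performed by Hydra) reduces to merging the
# selected group configs over the base config, later entries winning.
base = OmegaConf.create({"criterion": "cross_entropy", "lr_scheduler": "fixed"})
override = OmegaConf.create({"lr_scheduler": "inverse_sqrt"})
print(OmegaConf.to_yaml(OmegaConf.merge(base, override)))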
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.0
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 768
decoder_output_dim: 768
decoder_input_dim: 768
decoder_ffn_embed_dim: 3072
decoder_layers: 12
decoder_attention_heads: 12
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1600
decoder_output_dim: 1600
decoder_input_dim: 1600
decoder_ffn_embed_dim: 6400
decoder_layers: 48
decoder_attention_heads: 25
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1280
decoder_output_dim: 1280
decoder_input_dim: 1280
decoder_ffn_embed_dim: 5120
decoder_layers: 36
decoder_attention_heads: 20
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 24
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
@@ -1,5 +0,0 @@
# @package _group_
activation: gelu
vq_type: gumbel
vq_depth: 2
combine_groups: true
@@ -1,8 +0,0 @@
# @package _group_

quantize_targets: true
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
feature_grad_mult: 0.1
@@ -1,20 +0,0 @@
# @package _group_

quantize_targets: true
extractor_mode: layer_norm
layer_norm_first: true
final_dim: 768
latent_temp: [2.0,0.1,0.999995]
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
conv_bias: true

encoder_layers: 24
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16

feature_grad_mult: 1.0
@@ -1,36 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.criterions.fairseq_criterion import (  # noqa
    FairseqCriterion,
    LegacyFairseqCriterion,
)
from omegaconf import DictConfig


(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)


def build_criterion(cfg: DictConfig, task):
    return build_criterion_(cfg, task)


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.criterions." + file_name)
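The registry set up above is what makes the decorator pattern in the criterion files below work: each module registers itself by name when it is auto-imported. A minimal sketch of that pattern, with the criterion name and config field illustrative rather than taken from this diff:

from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyCriterionConfig(FairseqDataclass):
    sentence_avg: bool = False  # illustrative config field


@register_criterion("my_criterion", dataclass=MyCriterionConfig)
class MyCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # must return (loss, sample_size, logging_output)
        raise NotImplementedError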
@@ -1,123 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II


@dataclass
class AdaptiveLossConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")
    ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")


@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
    """This is an implementation of the loss function accompanying the adaptive softmax approximation for
    graphics processing units (GPUs), described in the paper "Efficient softmax approximation for GPUs"
    (http://arxiv.org/abs/1609.04309)."""

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    @classmethod
    def build_criterion(cls, cfg: AdaptiveLossConfig, task):
        if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
            raise Exception(
                "AdaptiveLoss is not compatible with the PyTorch "
                "version of DistributedDataParallel. Please use "
                "`--ddp-backend=legacy_ddp` instead."
            )
        return cls(task, cfg.sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        assert (
            hasattr(model.decoder, "adaptive_softmax")
            and model.decoder.adaptive_softmax is not None
        )
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample["net_input"])
        orig_target = model.get_targets(sample, net_output)

        nsentences = orig_target.size(0)
        orig_target = orig_target.view(-1)

        bsz = orig_target.size(0)

        logits, target = adaptive_softmax(net_output[0], orig_target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()

        for i in range(len(target)):
            if target[i] is not None:
                assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
                loss += F.cross_entropy(
                    logits[i],
                    target[i],
                    ignore_index=self.padding_idx,
                    reduction="sum" if reduce else "none",
                )

        orig = utils.strip_pad(orig_target, self.padding_idx)
        ntokens = orig.numel()
        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,100 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn


@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
    """This is a composite loss that, given a list of model outputs and a list of targets,
    computes an average of losses for each output-target pair"""

    def __init__(self, args, task):
        super().__init__(args, task)
        self.underlying_criterion = args.underlying_criterion

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
                            help='underlying criterion to use for the composite loss')
        # fmt: on

    @staticmethod
    def build_underlying_criterion(args, task):
        saved_criterion = args.criterion
        args.criterion = args.underlying_criterion
        assert saved_criterion != args.underlying_criterion
        underlying_criterion = task.build_criterion(args)
        args.criterion = saved_criterion
        return underlying_criterion

    @classmethod
    def build_criterion(cls, args, task):
        underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)

        class FakeModel(nn.Module):
            def __init__(self, model, net_out, target):
                super().__init__()
                self.model = model
                self.net_out = net_out
                self.target = target

            def forward(self, **unused):
                return self.net_out

            def get_normalized_probs(self, net_output, log_probs, sample=None):
                return self.model.get_normalized_probs(
                    net_output, log_probs, sample=sample
                )

            def get_targets(self, *unused):
                return self.target

            @property
            def decoder(self):
                return self.model.decoder

        class _CompositeLoss(LegacyFairseqCriterion):
            def __init__(self, args, task, underlying_criterion):
                super().__init__(args, task)
                self.underlying_criterion = underlying_criterion

            def forward(self, model, sample, reduce=True):
                net_outputs = model(**sample["net_input"])
                targets = sample["target"]

                bsz = targets[0].size(0)
                loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()

                sample_size = 0
                logging_output = {}
                for o, t in zip(net_outputs[0], targets):
                    m = FakeModel(model, (o, net_outputs[1]), t)
                    sample["target"] = t
                    l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
                    loss += l
                    sample_size += ss

                loss.div_(len(targets))
                sample_size /= len(targets)

                logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
                return loss, sample_size, logging_output

            @staticmethod
            def aggregate_logging_outputs(logging_outputs):
                return underlying_criterion.__class__.aggregate_logging_outputs(
                    logging_outputs
                )

            @staticmethod
            def reduce_metrics(logging_outputs) -> None:
                underlying_criterion.__class__.reduce_metrics(logging_outputs)

        return _CompositeLoss(args, task, underlying_criterion)
@@ -1,90 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss, loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
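The division by math.log(2) in reduce_metrics converts a summed negative log-likelihood from nats to bits, so the logged loss lines up directly with base-2 perplexity. A small numeric sketch (values illustrative):

import math

loss_sum, sample_size = 691.0, 100          # summed NLL in nats, tokens
loss_bits = loss_sum / sample_size / math.log(2)  # ~9.97 bits per token
perplexity = 2 ** loss_bits                       # ~1000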
@@ -1,319 +0,0 @@
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn.functional as F
from omegaconf import II

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.tasks import FairseqTask


@dataclass
class CtcCriterionConfig(FairseqDataclass):
    zero_infinity: bool = field(
        default=False,
        metadata={"help": "zero inf loss when source length <= target length"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    post_process: str = field(
        default="letter",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
            "wordpiece, BPE symbols, etc. "
            "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
        },
    )
    wer_lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "lexicon to use with wer_kenlm_model"},
    )
    wer_lm_weight: float = field(
        default=2.0,
        metadata={"help": "lm weight to use with wer_kenlm_model"},
    )
    wer_word_score: float = field(
        default=-1.0,
        metadata={"help": "lm word score to use with wer_kenlm_model"},
    )

    wer_args: Optional[str] = field(
        default=None,
        metadata={
            "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
        },
    )


@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
    def __init__(
        self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0
    ):
        super().__init__(task)
        self.blank_idx = (
            task.target_dictionary.index(task.blank_symbol)
            if hasattr(task, "blank_symbol")
            else 0
        )
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = cfg.post_process

        self.rdrop_alpha = rdrop_alpha

        if cfg.wer_args is not None:
            (
                cfg.wer_kenlm_model,
                cfg.wer_lexicon,
                cfg.wer_lm_weight,
                cfg.wer_word_score,
            ) = eval(cfg.wer_args)

        if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = cfg.wer_kenlm_model
            dec_args.lexicon = cfg.wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = cfg.wer_lm_weight
            dec_args.word_score = cfg.wer_word_score
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = cfg.zero_infinity
        self.sentence_avg = cfg.sentence_avg

    def forward(self, model, sample, reduce=True, **kwargs):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        # CTC loss is calculated over duplicated inputs
        # sample is already duplicated for R-Drop
        if self.rdrop_alpha > 0:
            for k, v in sample.items():
                if k in ["target", "target_lengths"]:
                    sample[k] = torch.cat([v, v.clone()], dim=0)
                elif k == "net_input":
                    if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size(
                        0
                    ):
                        # for decoder CTC loss
                        sample[k]["src_lengths"] = torch.cat(
                            [
                                sample[k]["src_lengths"],
                                sample[k]["src_lengths"].clone(),
                            ],
                            dim=0,
                        )

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            if net_output["padding_mask"] is not None:
                non_padding_mask = ~net_output["padding_mask"]
                input_lengths = non_padding_mask.long().sum(-1)
            else:
                input_lengths = lprobs.new_full(
                    (lprobs.size(1),), lprobs.size(0), dtype=torch.long
                )

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        if "target_lengths" in sample:
            target_lengths = sample["target_lengths"]
        else:
            target_lengths = pad_mask.sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
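For reference, the F.ctc_loss call inside CtcCriterion expects (T, B, C) log-probabilities and a packed 1-D target vector whose segments are delimited by target_lengths. A standalone shape sketch with illustrative values:

import torch
import torch.nn.functional as F

T, B, C = 50, 2, 32
lprobs = torch.randn(T, B, C).log_softmax(-1)   # (time, batch, vocab) log-probs
targets_flat = torch.randint(1, C, (20,))       # both targets concatenated, no blanks
input_lengths = torch.tensor([50, 50])
target_lengths = torch.tensor([12, 8])          # segments of targets_flat, sums to 20
loss = F.ctc_loss(lprobs, targets_flat, input_lengths, target_lengths,
                  blank=0, reduction="sum", zero_infinity=True)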
@@ -1,120 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Any, Dict, List

from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss


class FairseqCriterion(_Loss):
    def __init__(self, task):
        super().__init__()
        self.task = task
        if hasattr(task, "target_dictionary"):
            tgt_dict = task.target_dictionary
            self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add criterion-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @classmethod
    def build_criterion(cls, cfg: FairseqDataclass, task):
        """Construct a criterion from command-line args."""
        # arguments in the __init__.
        init_args = {}
        for p in inspect.signature(cls).parameters.values():
            if (
                p.kind == p.POSITIONAL_ONLY
                or p.kind == p.VAR_POSITIONAL
                or p.kind == p.VAR_KEYWORD
            ):
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(p.kind))

            assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}

            if p.name == "task":
                init_args["task"] = task
            elif p.name == "cfg":
                init_args["cfg"] = cfg
            elif hasattr(cfg, p.name):
                init_args[p.name] = getattr(cfg, p.name)
            elif p.default != p.empty:
                pass  # we'll use the default value
            else:
                raise NotImplementedError(
                    "Unable to infer Criterion arguments, please implement "
                    "{}.build_criterion".format(cls.__name__)
                )
        return cls(**init_args)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def aggregate_logging_outputs(
        logging_outputs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        raise NotImplementedError

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "Criterions should implement the reduce_metrics API. "
            "Falling back to deprecated aggregate_logging_outputs API."
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {"nsentences", "ntokens", "sample_size"}:
                continue
            metrics.log_scalar(k, v)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


class LegacyFairseqCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(task=task)
        self.args = args

        utils.deprecation_warning(
            "Criterions should take explicit arguments instead of an "
            "argparse.Namespace object, please update your criterion by "
            "extending FairseqCriterion instead of LegacyFairseqCriterion."
        )

    @classmethod
    def build_criterion(cls, args, task):
        """Construct a criterion from command-line args."""
        return cls(args, task)
@@ -1,136 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import List, Dict, Any
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel


@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})


@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
    def __init__(self, task, ctc_weight):
        super().__init__(task)
        self.ctc_weight = ctc_weight

    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]
        _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )

        src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
        tgt_mask = lengths_to_mask(sample["target_lengths"])

        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
        if _feat_out_post is not None:
            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )

        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in [
            "loss",
            "l1_loss",
            "dur_loss",
            "pitch_loss",
            "energy_loss",
            "ctc_loss",
        ]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
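The masking in FastSpeech2Loss relies on lengths_to_mask, which the file imports from fairseq.data.data_utils; it turns a vector of lengths into a boolean mask of valid positions. A tiny sketch with illustrative values:

import torch
from fairseq.data.data_utils import lengths_to_mask

lens = torch.tensor([3, 1])
mask = lengths_to_mask(lens)
# expected shape (2, 3): True marks valid (non-padded) frames
# tensor([[ True,  True,  True],
#         [ True, False, False]])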
@@ -1,194 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class HubertCriterionConfig(FairseqDataclass):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
    ):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.0
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float(net_output[lk])

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
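The compute_correct helper above counts a frame as correct when index 0 (the positive target in this logit layout) is the strict argmax, using the argmin check to exclude all-tied rows. A tiny standalone illustration (values made up):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5],    # argmax == 0 -> correct
                       [0.1, 3.0, 0.2],    # argmax != 0 -> incorrect
                       [1.0, 1.0, 1.0]])   # tie: argmax and argmin both 0 -> excluded
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
corr = (is_max & ~is_min).long().sum().item()  # 1, matches max.sum - both.sum
count = is_max.numel()                         # 3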
@@ -1,168 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from omegaconf import II
|
||||
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
|
||||
label_smoothing: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
||||
)
|
||||
report_accuracy: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "report accuracy metric"},
|
||||
)
|
||||
ignore_prefix_size: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Ignore first N tokens"},
|
||||
)
|
||||
sentence_avg: bool = II("optimization.sentence_avg")
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
|
||||
if target.dim() == lprobs.dim() - 1:
|
||||
target = target.unsqueeze(-1)
|
||||
nll_loss = -lprobs.gather(dim=-1, index=target)
|
||||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
|
||||
if ignore_index is not None:
|
||||
pad_mask = target.eq(ignore_index)
|
||||
nll_loss.masked_fill_(pad_mask, 0.0)
|
||||
smooth_loss.masked_fill_(pad_mask, 0.0)
|
||||
else:
|
||||
nll_loss = nll_loss.squeeze(-1)
|
||||
smooth_loss = smooth_loss.squeeze(-1)
|
||||
if reduce:
|
||||
nll_loss = nll_loss.sum()
|
||||
smooth_loss = smooth_loss.sum()
|
||||
eps_i = epsilon / (lprobs.size(-1) - 1)
|
||||
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
||||
return loss, nll_loss
|
||||
|
||||
|
||||


@register_criterion(
    "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        if self.ignore_prefix_size > 0:
            # lprobs: B x T x C
            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
            target = target[:, self.ignore_prefix_size :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )
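
        # Annotation: dividing by math.log(2) converts summed losses from nats
        # to bits, so "loss" and "nll_loss" are logged in bits per token and
        # the derived "ppl" is 2 ** nll_loss via utils.get_perplexity.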

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,220 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)

try:
    from simuleval.metrics.latency import (
        AverageLagging,
        AverageProportion,
        DifferentiableAverageLagging,
    )

    LATENCY_METRICS = {
        "average_lagging": AverageLagging,
        "average_proportion": AverageProportion,
        "differentiable_average_lagging": DifferentiableAverageLagging,
    }
except ImportError:
    LATENCY_METRICS = None


@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    latency_avg_weight: float = field(
        default=0.0,
        metadata={"help": "weight for average latency loss."},
    )
    latency_var_weight: float = field(
        default=0.0,
        metadata={"help": "weight for variance latency loss."},
    )
    latency_avg_type: str = field(
        default="differentiable_average_lagging",
        metadata={"help": "latency type for average loss"},
    )
    latency_var_type: str = field(
        default="variance_delay",
        metadata={"help": "latency type for variance loss"},
    )
    latency_gather_method: str = field(
        default="weighted_average",
        metadata={"help": "method to gather latency loss for all heads"},
    )
    latency_update_after: int = field(
        default=0,
        metadata={"help": "Add latency loss after certain steps"},
    )


@register_criterion(
    "latency_augmented_label_smoothed_cross_entropy",
    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        latency_avg_weight,
        latency_var_weight,
        latency_avg_type,
        latency_var_type,
        latency_gather_method,
        latency_update_after,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."

        self.latency_avg_weight = latency_avg_weight
        self.latency_var_weight = latency_var_weight
        self.latency_avg_type = latency_avg_type
        self.latency_var_type = latency_var_type
        self.latency_gather_method = latency_gather_method
        self.latency_update_after = latency_update_after

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        # 1. Compute cross entropy loss
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        # 2. Compute latency loss
        latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
            model, sample, net_output
        )

        if self.latency_update_after > 0:
            num_updates = getattr(model.decoder, "num_updates", None)
            assert (
                num_updates is not None
            ), "model.decoder doesn't have attribute 'num_updates'"
            if num_updates <= self.latency_update_after:
                latency_loss = 0

        loss += latency_loss

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "latency": expected_latency,
            "delays_var": expected_delays_var,
            "latency_loss": latency_loss,
        }

        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def compute_latency_loss(self, model, sample, net_output):
        assert (
            net_output[-1].encoder_padding_mask is None
            or not net_output[-1].encoder_padding_mask[:, 0].any()
        ), "Only right padding on source is supported."
        # 1. Obtain the expected alignment
        alpha_list = [item["alpha"] for item in net_output[1].attn_list]
        num_layers = len(alpha_list)
        bsz, num_heads, tgt_len, src_len = alpha_list[0].size()

        # bsz * num_layers * num_heads, tgt_len, src_len
        alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)

        # 2. Compute expected delays
        # bsz * num_heads * num_layers, tgt_len, src_len for MMA
        steps = (
            torch.arange(1, 1 + src_len)
            .unsqueeze(0)
            .unsqueeze(1)
            .expand_as(alpha_all)
            .type_as(alpha_all)
        )

        expected_delays = torch.sum(steps * alpha_all, dim=-1)
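
        # Annotation: expected_delays[i] = sum_j j * alpha[i, j] is the
        # expected source position attended to when emitting target step i,
        # i.e. the expectation of the delay under the attention distribution;
        # the SimulEval latency metrics below are computed from these values.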

        target_padding_mask = (
            model.get_targets(sample, net_output)
            .eq(self.padding_idx)
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads, tgt_len)
            .contiguous()
            .view(-1, tgt_len)
        )

        src_lengths = (
            sample["net_input"]["src_lengths"]
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads)
            .contiguous()
            .view(-1)
        )
        expected_latency = LATENCY_METRICS[self.latency_avg_type](
            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
        )

        # 2.1 average expected latency of heads
        # bsz, num_layers * num_heads
        expected_latency = expected_latency.view(bsz, -1)
        if self.latency_gather_method == "average":
            # bsz * tgt_len
            expected_latency = expected_delays.mean(dim=1)
        elif self.latency_gather_method == "weighted_average":
            weights = torch.nn.functional.softmax(expected_latency, dim=1)
            expected_latency = torch.sum(expected_latency * weights, dim=1)
        elif self.latency_gather_method == "max":
            expected_latency = expected_latency.max(dim=1)[0]
        else:
            raise NotImplementedError

        expected_latency = expected_latency.sum()
        avg_loss = self.latency_avg_weight * expected_latency

        # 2.2 variance of expected delays
        expected_delays_var = (
            expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
        )
        expected_delays_var = expected_delays_var.sum()
        # scale the variance term by its own weight, matching the config help
        # ("weight for variance latency loss"); the original line reused
        # latency_avg_weight here
        var_loss = self.latency_var_weight * expected_delays_var

        # 3. Final loss
        latency_loss = avg_loss + var_loss

        return latency_loss, expected_latency, expected_delays_var

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        latency = sum(log.get("latency", 0) for log in logging_outputs)
        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
        metrics.log_scalar(
            "latency_loss", latency_loss / nsentences, nsentences, round=3
        )
@@ -1,130 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import register_criterion

from .label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)

from dataclasses import dataclass, field


@dataclass
class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    alignment_lambda: float = field(
        default=0.05, metadata={"help": "weight for the alignment loss"}
    )


@register_criterion(
    "label_smoothed_cross_entropy_with_alignment",
    dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig,
)
class LabelSmoothedCrossEntropyCriterionWithAlignment(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
        super().__init__(task, sentence_avg, label_smoothing)
        self.alignment_lambda = alignment_lambda

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        alignment_loss = None

        # Compute alignment loss only for the training set and non-dummy batches.
        if "alignments" in sample and sample["alignments"] is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output["alignment_loss"] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        attn_prob = net_output[1]["attn"][0]
        bsz, tgt_sz, src_sz = attn_prob.shape
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample["alignments"]
        align_weights = sample["align_weights"].float()

        if len(align) > 0:
            # Alignment loss computation. align (shape [:, 2]) contains the
            # src-tgt index pairs corresponding to the alignments.
            # align_weights (shape [:]) contains the 1 / frequency of a tgt
            # index for normalizing.
            loss = -(
                (attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
                * align_weights[:, None]
            ).sum()
        else:
            return None

        return loss
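
    # Annotation: per the comment above, align holds (src, tgt) index pairs
    # into the flattened (bsz * tgt_len, src_len) attention matrix (target
    # indices are assumed to be pre-offset across the batch by the collater),
    # so the loss is the weighted negative log attention mass on each
    # gold-aligned source position.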

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss_sum = utils.item(
            sum(log.get("nll_loss", 0) for log in logging_outputs)
        )
        alignment_loss_sum = utils.item(
            sum(log.get("alignment_loss", 0) for log in logging_outputs)
        )
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        metrics.log_scalar(
            "alignment_loss",
            alignment_loss_sum / sample_size / math.log(2),
            sample_size,
            round=3,
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,96 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.data.data_utils import lengths_to_mask


@dataclass
class LabelSmoothedCrossEntropyWithCtcCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"})


@register_criterion(
    "label_smoothed_cross_entropy_with_ctc",
    dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig,
)
class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        ctc_weight,
    ):
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        self.ctc_weight = ctc_weight

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        ctc_loss = torch.tensor(0.0).type_as(loss)
        if self.ctc_weight > 0.0:
            ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample)
            ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample)
            ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens)
            ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask)
            reduction = "sum" if reduce else "none"
            ctc_loss = (
                F.ctc_loss(
                    ctc_lprobs,
                    ctc_tgt_flat,
                    ctc_lens,
                    ctc_tgt_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )
            loss += ctc_loss
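
            # Annotation: F.ctc_loss expects log-probabilities of shape
            # (T, N, C) plus a 1-D tensor of concatenated, unpadded targets,
            # which is why ctc_tgt is flattened through the length mask;
            # zero_infinity=True zeroes out losses for inputs shorter than
            # their targets rather than propagating inf.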

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),
            "nll_loss": utils.item(nll_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
@@ -1,176 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    LabelSmoothedCrossEntropyCriterionConfig,
    label_smoothed_nll_loss,
)


@dataclass
class RdropLabelSmoothedCrossEntropyCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    rdrop_alpha: float = field(
        default=0.0,
        metadata={"help": "alpha for r-drop, 0 means no r-drop"},
    )


@register_criterion(
    "label_smoothed_cross_entropy_with_rdrop",
    dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig,
)
class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        rdrop_alpha=0.0,
    ):
        super().__init__(
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size=ignore_prefix_size,
            report_accuracy=report_accuracy,
        )
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy
        self.rdrop_alpha = rdrop_alpha

    def forward(self, model, sample, reduce=True, net_output=None):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        if net_output is None:
            if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size(
                0
            ) == sample["target"].size(0):
                sample = duplicate_input(sample)
            net_output = model(**sample["net_input"])
        loss, nll_loss, rdrop_kl_loss = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        if self.rdrop_alpha > 0:
            logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
        return loss, sample_size, logging_output

    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0):
            target = torch.cat([target, target.clone()], dim=0)

        if self.ignore_prefix_size > 0:
            # lprobs: B x T x C
            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
            target = target[:, self.ignore_prefix_size :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )

        if self.rdrop_alpha > 0:
            pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx)
            rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask)
            loss += self.rdrop_alpha * rdrop_kl_loss
        else:
            rdrop_kl_loss = loss.new_zeros(1)
        return loss, nll_loss, rdrop_kl_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        super().reduce_metrics(logging_outputs)

        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        rdrop_kl_loss = utils.item(
            sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs)
            / sample_size
            / math.log(2)
        )
        if rdrop_kl_loss > 0:
            metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss)


def duplicate_input(sample):
    if "net_input" in sample.keys():
        sample_input = sample["net_input"]
    else:
        sample_input = sample

    for k, v in sample_input.items():
        if isinstance(v, torch.Tensor):
            sample_input[k] = torch.cat([v, v.clone()], dim=0)
    if "net_input" in sample.keys():
        sample["net_input"] = sample_input
    else:
        sample = sample_input
    return sample


def compute_kl_loss(model, net_output, pad_mask=None, reduce=True):
    net_prob = model.get_normalized_probs(net_output, log_probs=True)
    net_prob_tec = model.get_normalized_probs(net_output, log_probs=False)

    net_prob = net_prob.view(-1, net_prob.size(-1))
    net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1))

    p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0)
    p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0)

    p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none")
    q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none")

    if pad_mask is not None:
        p_loss.masked_fill_(pad_mask, 0.0)
        q_loss.masked_fill_(pad_mask, 0.0)

    if reduce:
        p_loss = p_loss.sum()
        q_loss = q_loss.sum()

    loss = (p_loss + q_loss) / 2
    return loss
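
# Annotation: this implements the R-Drop recipe: duplicate_input runs each
# example through the model twice so dropout yields two different predictive
# distributions, and compute_kl_loss penalizes the symmetric KL divergence
# between the two halves. Note that torch.nn.functional.kl_div takes
# log-probabilities as its first argument and probabilities as its second,
# which is why both normalized forms are computed above.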
@@ -1,177 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
    """
    Function to compute the cross entropy loss. The default value of
    ignore_index is the same as the default value for F.cross_entropy in
    pytorch.
    """
    assert logits.size(0) == targets.size(
        -1
    ), "Logits and Targets tensor shapes don't match up"

    loss = F.nll_loss(
        F.log_softmax(logits, -1, dtype=torch.float32),
        targets,
        reduction="sum",
        ignore_index=ignore_index,
    )
    return loss
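
# Annotation: F.nll_loss applied to F.log_softmax outputs is equivalent to
# F.cross_entropy; the split form presumably keeps the softmax in float32
# (dtype=torch.float32) for numerical stability under mixed precision.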
@register_criterion("legacy_masked_lm_loss")
|
||||
class LegacyMaskedLmLoss(FairseqCriterion):
|
||||
"""
|
||||
Implementation for the loss used in masked language model (MLM) training.
|
||||
This optionally also computes the next sentence prediction (NSP) loss and
|
||||
adds it to the overall loss based on the specified args. There are three
|
||||
cases to consider:
|
||||
1) Generic MLM training without NSP loss. In this case sentence_targets
|
||||
and sentence_logits are both None.
|
||||
2) BERT training without NSP loss. In this case sentence_targets is
|
||||
not None but sentence_logits is None and we should not be computing
|
||||
a sentence level loss.
|
||||
3) BERT training with NSP loss. In this case both sentence_targets and
|
||||
sentence_logits are not None and we should be computing a sentence
|
||||
level loss. The weight of the sentence level loss is specified as
|
||||
an argument.
|
||||
"""
|
||||
|
||||
def __init__(self, task, masked_lm_only, nsp_loss_weight):
|
||||
super().__init__(task)
|
||||
self.masked_lm_only = masked_lm_only
|
||||
self.nsp_loss_weight = nsp_loss_weight
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Args for MaskedLM Loss"""
|
||||
# Default for masked_lm_only is False so as to not break BERT training
|
||||
parser.add_argument(
|
||||
"--masked-lm-only",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="compute MLM loss only",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nsp-loss-weight",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help="weight for next sentence prediction" " loss (default 1)",
|
||||
)
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
lm_logits, output_metadata = model(**sample["net_input"])
|
||||
|
||||
# reshape lm_logits from (N,T,C) to (N*T,C)
|
||||
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
|
||||
lm_targets = sample["lm_target"].view(-1)
|
||||
lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
|
||||
|
||||
# compute the number of tokens for which loss is computed. This is used
|
||||
# to normalize the loss
|
||||
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
|
||||
loss = lm_loss / ntokens
|
||||
nsentences = sample["nsentences"]
|
||||
# nsentences = 0
|
||||
|
||||
# Compute sentence loss if masked_lm_only is False
|
||||
sentence_loss = None
|
||||
if not self.masked_lm_only:
|
||||
sentence_logits = output_metadata["sentence_logits"]
|
||||
sentence_targets = sample["sentence_target"].view(-1)
|
||||
# This needs to be recomputed due to some differences between
|
||||
# TokenBlock and BlockPair dataset. This can be resolved with a
|
||||
# refactor of BERTModel which we will do in the future.
|
||||
# TODO: Remove this after refactor of BERTModel
|
||||
nsentences = sentence_targets.size(0)
|
||||
|
||||
# Check for logits being none which can happen when remove_heads
|
||||
# is set to true in the BERT model. Ideally we should set
|
||||
# masked_lm_only to true in this case, but that requires some
|
||||
# refactor in the BERT model.
|
||||
if sentence_logits is not None:
|
||||
sentence_loss = compute_cross_entropy_loss(
|
||||
sentence_logits, sentence_targets
|
||||
)
|
||||
|
||||
loss += self.nsp_loss_weight * (sentence_loss / nsentences)
|
||||
|
||||
# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
|
||||
# we don't need to use sample_size as denominator for the gradient
|
||||
# here sample_size is just used for logging
|
||||
sample_size = 1
|
||||
logging_output = {
|
||||
"loss": utils.item(loss.data) if reduce else loss.data,
|
||||
"lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
|
||||
# sentence loss is not always computed
|
||||
"sentence_loss": (
|
||||
(utils.item(sentence_loss.data) if reduce else sentence_loss.data)
|
||||
if sentence_loss is not None
|
||||
else 0.0
|
||||
),
|
||||
"ntokens": ntokens,
|
||||
"nsentences": nsentences,
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
|
||||
sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
|
||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss",
|
||||
agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
|
||||
sample_size,
|
||||
round=3,
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"lm_loss",
|
||||
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
|
||||
ntokens,
|
||||
round=3,
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"sentence_loss",
|
||||
sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
|
||||
nsentences,
|
||||
round=3,
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"nll_loss",
|
||||
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
|
||||
ntokens,
|
||||
round=3,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return True
|
||||
@@ -1,98 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import math
from omegaconf import II

import torch
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MaskedLmConfig(FairseqDataclass):
    tpu: bool = II("common.tpu")


@register_criterion("masked_lm", dataclass=MaskedLmConfig)
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    """

    def __init__(self, cfg: MaskedLmConfig, task):
        super().__init__(task)
        self.tpu = cfg.tpu

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_tokens = sample["target"].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()

        # Rare: when no tokens are masked (nothing to predict), project all
        # tokens. We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device("cpu"):
            if not masked_tokens.any():
                masked_tokens = None
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )
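
        # Annotation: routing the fallback through torch.where keeps the
        # decision on-device; evaluating masked_tokens.any() as a Python bool
        # would force a device-to-host sync on every batch, as the comment
        # above notes.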

        logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )

        logging_output = {
            "loss": loss if self.tpu else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,155 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from typing import Dict, List

import torch

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


logger = logging.getLogger(__name__)


@dataclass
class ModelCriterionConfig(FairseqDataclass):
    loss_weights: Dict[str, float] = field(
        default_factory=dict,
        metadata={"help": "weights for the loss terms"},
    )
    log_keys: List[str] = field(
        default_factory=list,
        metadata={"help": "additional output keys to log"},
    )


@register_criterion("model", dataclass=ModelCriterionConfig)
class ModelCriterion(FairseqCriterion):
    """
    This criterion relies on the model to supply losses.
    The losses should be a dictionary of name -> scalar returned by
    the model either by including it in the net_output dict or by
    implementing a get_losses(net_output, sample) method. The final loss is
    a scaled sum of all losses according to weights in loss_weights.
    If no weights are provided, then all losses are scaled by 1.0.

    The losses will be automatically logged. Additional keys from
    net_output dict can be logged via the log_keys parameter.
    """

    def __init__(self, task, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.loss_weights = loss_weights
        self.log_keys = log_keys

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])

        scaled_losses = {}

        if hasattr(model, "get_losses"):
            losses = model.get_losses(net_output, sample)
        elif isinstance(net_output, dict) and "losses" in net_output:
            losses = net_output["losses"]
        else:
            raise Exception("Could not retrieve losses")

        for lk, p in losses.items():
            try:
                coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk]
            except KeyError:
                logger.error(
                    f"weight for loss {lk} is not in loss_weights ({self.loss_weights})"
                )
                raise
            if coef != 0 and p is not None:
                scaled_losses[lk] = coef * p.float()

        loss = sum(scaled_losses.values())
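
        # Annotation: the final objective is sum_k w_k * loss_k, with w_k
        # taken from loss_weights (defaulting to 1.0 for every term when the
        # dict is empty); zero-weighted or None-valued terms are dropped
        # before the sum.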
if "sample_size" in net_output:
|
||||
sample_size = net_output["sample_size"]
|
||||
else:
|
||||
sample_size = loss.numel()
|
||||
|
||||
if reduce and loss.numel() > 1:
|
||||
loss = loss.sum()
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"ntokens": sample_size,
|
||||
"nsentences": sample["id"].numel(),
|
||||
"sample_size": sample_size,
|
||||
"_world_size": 1,
|
||||
}
|
||||
|
||||
for lk in self.log_keys:
|
||||
if lk in net_output and net_output[lk] is not None:
|
||||
if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
|
||||
logging_output[lk] = float(net_output[lk])
|
||||
else:
|
||||
for i, v in enumerate(net_output[lk]):
|
||||
logging_output[f"{lk}_{i}"] = float(v)
|
||||
|
||||
if len(scaled_losses) > 1:
|
||||
for lk, l in scaled_losses.items():
|
||||
if l.numel() > 1:
|
||||
l = l.sum()
|
||||
logging_output[f"loss_{lk}"] = l.item()
|
||||
|
||||
if "logs" in net_output:
|
||||
for lgw in net_output["logs"]:
|
||||
logging_output[lgw] = net_output["logs"][lgw]
|
||||
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
||||
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
||||
nsentences = utils.item(
|
||||
sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
)
|
||||
sample_size = utils.item(
|
||||
sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
)
|
||||
|
||||
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
|
||||
metrics.log_scalar("ntokens", ntokens)
|
||||
metrics.log_scalar("nsentences", nsentences)
|
||||
|
||||
builtin_keys = {
|
||||
"loss",
|
||||
"ntokens",
|
||||
"nsentences",
|
||||
"sample_size",
|
||||
"_world_size",
|
||||
}
|
||||
|
||||
world_size = utils.item(
|
||||
sum(log.get("_world_size", 0) for log in logging_outputs)
|
||||
)
|
||||
|
||||
for k in logging_outputs[0]:
|
||||
if k not in builtin_keys:
|
||||
val = sum(log.get(k, 0) for log in logging_outputs)
|
||||
if k.startswith("loss_"):
|
||||
metrics.log_scalar(k, val / sample_size, sample_size, round=3)
|
||||
else:
|
||||
metrics.log_scalar(k, val / world_size, round=3)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return True
|
||||
@@ -1,180 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from torch import Tensor

from dataclasses import dataclass, field


@dataclass
class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass):
    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )


@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig)
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
    def __init__(self, task, label_smoothing):
        super().__init__(task)
        self.label_smoothing = label_smoothing

    def _compute_loss(
        self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
    ):
        """
        outputs: batch x len x d_model
        targets: batch x len
        masks:   batch x len

        policy_logprob: if there is some policy that depends on the
            likelihood score as rewards.
        """

        def mean_ds(x: Tensor, dim=None) -> Tensor:
            return (
                x.float().mean().type_as(x)
                if dim is None
                else x.float().mean(dim).type_as(x)
            )

        if masks is not None:
            outputs, targets = outputs[masks], targets[masks]

        if masks is not None and not masks.any():
            nll_loss = torch.tensor(0)
            loss = nll_loss
        else:
            logits = F.log_softmax(outputs, dim=-1)
            if targets.dim() == 1:
                losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
            else:  # soft-labels
                losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
                losses = losses.sum(-1)

            nll_loss = mean_ds(losses)
            if label_smoothing > 0:
                loss = (
                    nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
                )
            else:
                loss = nll_loss

        loss = loss * factor
        return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
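
    # Annotation: the smoothed objective above is
    # (1 - ls) * nll - ls * mean(log_softmax(outputs)); the second term is,
    # up to a constant, the cross entropy against a uniform distribution, so
    # this matches standard uniform label smoothing.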

    def _custom_loss(self, loss, name="loss", factor=1.0):
        return {"name": name, "loss": loss, "factor": factor}

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]

        # B x T
        src_tokens, src_lengths = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
        )
        tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]

        outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
        losses, nll_loss = [], []

        for obj in outputs:
            if outputs[obj].get("loss", None) is None:
                _losses = self._compute_loss(
                    outputs[obj].get("out"),
                    outputs[obj].get("tgt"),
                    outputs[obj].get("mask", None),
                    outputs[obj].get("ls", 0.0),
                    name=obj + "-loss",
                    factor=outputs[obj].get("factor", 1.0),
                )
            else:
                _losses = self._custom_loss(
                    outputs[obj].get("loss"),
                    name=obj + "-loss",
                    factor=outputs[obj].get("factor", 1.0),
                )

            losses += [_losses]
            if outputs[obj].get("nll_loss", False):
                nll_loss += [_losses.get("nll_loss", 0.0)]

        loss = sum(l["loss"] for l in losses)
        nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)

        # NOTE:
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        for l in losses:
            logging_output[l["name"]] = (
                utils.item(l["loss"].data / l["factor"])
                if reduce
                else l["loss"].data / l["factor"]
            )

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )
        loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))

        metrics.log_scalar(
            "loss", loss / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

        for key in logging_outputs[0]:
            if key[-5:] == "-loss":
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(
                    key[:-5],
                    val / sample_size / math.log(2) if sample_size > 0 else 0.0,
                    sample_size,
                    round=3,
                )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,141 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class SentencePredictionConfig(FairseqDataclass):
    classification_head_name: str = field(
        default="sentence_classification_head",
        metadata={"help": "name of the classification head to use"},
    )
    regression_target: bool = field(
        default=False,
    )


@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionCriterion(FairseqCriterion):
    def __init__(self, cfg: SentencePredictionConfig, task):
        super().__init__(task)
        self.classification_head_name = cfg.classification_head_name
        self.regression_target = cfg.regression_target

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            task_loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            logits = logits.view(-1).float()
            targets = targets.float()
            task_loss = F.mse_loss(logits, targets, reduction="sum")
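
        # Annotation: classification heads are trained with a summed NLL over
        # log-softmax scores, while regression heads (e.g. STS-B-style tasks)
        # treat the single output logit as a real-valued prediction under a
        # summed MSE loss.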

        logging_output = {}
        loss = task_loss
        # mha & ffn regularization update
        if (
            hasattr(model.args, "mha_reg_scale_factor")
            and model.args.mha_reg_scale_factor != 0.0
        ):
            mha_reg_loss = model._get_adaptive_head_loss()
            loss += mha_reg_loss
            logging_output.update({"mha_reg_loss": mha_reg_loss})
        if (
            hasattr(model.args, "ffn_reg_scale_factor")
            and model.args.ffn_reg_scale_factor != 0.0
        ):
            ffn_reg_loss = model._get_adaptive_ffn_loss()
            loss += ffn_reg_loss
            logging_output.update({"ffn_reg_loss": ffn_reg_loss})

        logging_output.update(
            {
                "loss": loss.data,
                "ntokens": sample["ntokens"],
                "nsentences": sample_size,
                "sample_size": sample_size,
            }
        )
        if not self.regression_target:
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs)
        ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if mha_reg_loss_sum:
            metrics.log_scalar(
                "mha_reg_loss",
                mha_reg_loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )
        if ffn_reg_loss_sum:
            metrics.log_scalar(
                "ffn_reg_loss",
                ffn_reg_loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,63 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from fairseq.criterions import register_criterion
from fairseq.criterions.sentence_prediction import (
    SentencePredictionCriterion,
    SentencePredictionConfig,
)


@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.classification_head_name in model.classification_heads
        ), "model must provide sentence classification head for --criterion=sentence_prediction"

        # sample is a dict, so key membership (rather than the original
        # hasattr check, which never sees dict keys) detects lang_id
        if "lang_id" not in sample:
            # If no language ID is given, we fall back to English
            lang_id = ["en_XX"] * sample["nsentences"]
        else:
            lang_id = sample["lang_id"]

        logits, _ = model(
            **sample["net_input"],
            features_only=True,
            classification_head_name=self.classification_head_name,
            lang_id=lang_id,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.regression_target:
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(lprobs, targets, reduction="sum")
        else:
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(logits, targets, reduction="sum")

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample_size,
            "sample_size": sample_size,
        }
        if not self.regression_target:
            preds = logits.argmax(dim=1)
            logging_output["ncorrect"] = (preds == targets).sum()

        return loss, sample_size, logging_output
@@ -1,120 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
    def __init__(self, task, ranking_head_name, save_predictions, num_classes):
        super().__init__(task)
        self.ranking_head_name = ranking_head_name
        if save_predictions is not None:
            self.prediction_h = open(save_predictions, "w")
        else:
            self.prediction_h = None
        self.num_classes = num_classes

    def __del__(self):
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--save-predictions', metavar='FILE',
                            help='file to save predictions to')
        parser.add_argument('--ranking-head-name',
                            default='sentence_classification_head',
                            help='name of the ranking head to use')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, "classification_heads")
            and self.ranking_head_name in model.classification_heads
        ), "model must provide sentence ranking head for --criterion=sentence_ranking"

        scores = []
        for idx in range(self.num_classes):
            score, _ = model(
                **sample["net_input{idx}".format(idx=idx + 1)],
                classification_head_name=self.ranking_head_name,
            )
            scores.append(score)

        logits = torch.cat(scores, dim=1)
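
        # Annotation: each of the num_classes candidates arrives in its own
        # net_input1..net_inputN field and is scored independently by the
        # shared ranking head; concatenating the scores yields a
        # (batch, num_classes) logit matrix, reducing ranking to
        # classification over the candidate set.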
sample_size = logits.size(0)
|
||||
|
||||
if "target" in sample:
|
||||
targets = model.get_targets(sample, [logits]).view(-1)
|
||||
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||
loss = F.nll_loss(lprobs, targets, reduction="sum")
|
||||
else:
|
||||
targets = None
|
||||
loss = torch.tensor(0.0, requires_grad=True)
|
||||
|
||||
if self.prediction_h is not None:
|
||||
preds = logits.argmax(dim=1)
|
||||
for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
|
||||
if targets is not None:
|
||||
label = targets[i].item()
|
||||
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
|
||||
else:
|
||||
print("{}\t{}".format(id, pred), file=self.prediction_h)
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"ntokens": sample["ntokens"],
|
||||
"nsentences": sample_size,
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
if targets is not None:
|
||||
logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
|
||||
|
||||
return loss, sample_size, logging_output
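
        # Note: each of the num_classes candidates is scored independently by
        # the same ranking head via sample["net_input1"] ... sample["net_inputN"],
        # and the concatenated scores form (batch, num_classes) logits, so
        # ranking reduces to picking the highest-scoring candidate, e.g.:
        #
        #   logits = torch.cat([score_1, score_2, score_3], dim=1)  # (bsz, 3)
        #   preds = logits.argmax(dim=1)  # index of the preferred candidate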

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
            ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
            metrics.log_scalar(
                "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,516 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from collections import OrderedDict

import torch

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.ctc import CtcCriterion
from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import (
    RdropLabelSmoothedCrossEntropyCriterion,
    RdropLabelSmoothedCrossEntropyCriterionConfig,
    duplicate_input,
)
from fairseq.criterions.tacotron2_loss import (
    Tacotron2Criterion,
    Tacotron2CriterionConfig,
)

logger = logging.getLogger(__name__)


class MultitaskCriterion:
    def __init__(self, multitask_tasks, rdrop_alpha=0.0):
        self.rdrop_alpha = rdrop_alpha
        self.rdrop_alpha_mtl = rdrop_alpha

        self.multitask_criterion = OrderedDict()
        self.multitask_loss_weight = OrderedDict()
        for task_name, task_obj in multitask_tasks.items():
            if task_obj.args.get_loss_weight(0) == 0:
                logger.info(f"Skip {task_name} loss criterion")
                continue

            rdrop_alpha_task = task_obj.args.rdrop_alpha
            if rdrop_alpha_task is None:
                rdrop_alpha_task = rdrop_alpha
            self.rdrop_alpha_mtl = rdrop_alpha_task
            logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}")

            if task_obj.args.decoder_type == "ctc":
                self.multitask_criterion[task_name] = CtcCriterion(
                    task_obj.args.criterion_cfg,
                    task_obj,
                    rdrop_alpha=rdrop_alpha_task,
                )
            else:
                self.multitask_criterion[
                    task_name
                ] = RdropLabelSmoothedCrossEntropyCriterion(
                    task_obj,
                    task_obj.args.criterion_cfg.sentence_avg,
                    label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
                    rdrop_alpha=rdrop_alpha_task,
                )

    def set_multitask_loss_weight(self, task_name, weight=0.0):
        self.multitask_loss_weight[task_name] = weight

    def get_multitask_loss(self, model, sample, model_out):
        logging_output = {}
        loss = 0.0
        for task_name, task_criterion in self.multitask_criterion.items():
            layer_id = task_criterion.task.args.input_layer
            if isinstance(task_criterion, CtcCriterion):
                if task_criterion.task.args.input_from == "encoder":
                    if len(model_out["encoder_padding_mask"]) > 0:
                        non_padding_mask = ~model_out["encoder_padding_mask"][0]
                        input_lengths = non_padding_mask.long().sum(-1)
                    else:
                        out = model_out["encoder_states"][layer_id]
                        input_lengths = out.new_full(
                            (out.shape[1],), out.shape[0]
                        ).long()

                    task_sample = {
                        "net_input": {
                            "src_tokens": model_out["encoder_states"][
                                layer_id
                            ],  # check batch idx
                            "src_lengths": input_lengths,
                        },
                        "id": sample["id"],
                    }
                else:
                    task_sample = {
                        "net_input": {
                            "src_tokens": model_out["inner_states"][layer_id],
                            "src_lengths": sample["target_lengths"],
                        },
                        "id": sample["id"],
                    }
            else:
                task_sample = {
                    "net_input": {
                        "src_tokens": sample["multitask"][task_name]["net_input"][
                            "prev_output_tokens"
                        ],
                        "encoder_out": {
                            "encoder_out": [model_out["encoder_states"][layer_id]],
                            "encoder_padding_mask": model_out["encoder_padding_mask"],
                        },
                    }
                }

            for key in ["target", "target_lengths", "ntokens"]:
                task_sample[key] = sample["multitask"][task_name][key]

            if task_name == getattr(model, "mt_task_name", None):
                decoder_out = model_out["mt_decoder_out"]
            else:
                decoder_out = None
            task_loss, task_sample_size, task_logging_output = task_criterion(
                model.multitask_decoders[task_name], task_sample, net_output=decoder_out
            )

            loss = loss + self.multitask_loss_weight[task_name] * task_loss
            task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
            logging_output[task_name] = task_logging_output
        return loss, logging_output
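
        # Note: the combined objective is the main loss plus a weighted sum
        # of the auxiliary losses, i.e. loss = main + sum_i(w_i * loss_i),
        # with each w_i set via set_multitask_loss_weight() (typically from
        # the owning task's loss-weight schedule). A weight of 0 removes a
        # task's gradient contribution while keeping its logging output.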

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        for task_name in logging_outputs[0]["multitask"].keys():
            # different criterion may return different logging
            # currently only reduce on loss, the most common one
            # ideally the way that losses are reduced should also depend on the task type
            loss_sum = sum(
                log["multitask"][task_name].get("loss", 0) for log in logging_outputs
            )
            sample_size = sum(
                log["multitask"][task_name].get("sample_size", 0)
                for log in logging_outputs
            )

            metrics.log_scalar(
                f"multitask_{task_name}_loss",
                loss_sum / sample_size / math.log(2),
                sample_size,
                round=3,
            )

            loss_weight = logging_outputs[0]["multitask"][task_name].get(
                "loss_weight", 0
            )
            metrics.log_scalar(
                f"multitask_{task_name}_loss_weight",
                loss_weight,
                weight=0,
                priority=250,
            )


@register_criterion(
    "speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnitMultitaskTaskCriterion(
    RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        rdrop_alpha=0.0,
    ):
        super().__init__(
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            rdrop_alpha,
        )
        MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha)

    def forward(self, model, sample, reduce=True):
        net_input_concat = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
            "prev_output_tokens": sample["net_input"]["prev_output_tokens"],
            "tgt_speaker": sample["net_input"].get("tgt_speaker", None),
            "return_all_hiddens": True,
        }

        if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
            net_input_concat = duplicate_input(net_input_concat)

        net_output, extra = model(**net_input_concat)
        loss, nll_loss, rdrop_kl_loss = self.compute_loss(
            model, [net_output], sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, [net_output], sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        if self.rdrop_alpha > 0:
            logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log

        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)

        # inference metrics
        if "targ_frames" in logging_outputs[0]:
            n = sum(log.get("norm_frames", 0) for log in logging_outputs)
            for key, new_key in [
                ("mcd_loss", "mcd_loss"),
                ("pred_frames", "pred_ratio"),
                ("nins", "ins_rate"),
                ("ndel", "del_rate"),
            ]:
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(new_key, val / n, n, round=3)

        if "multitask" not in logging_outputs[0]:
            return

        MultitaskCriterion.reduce_metrics(logging_outputs)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False


@register_criterion(
    "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        rdrop_alpha=0.0,
    ):
        super().__init__(
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            rdrop_alpha,
        )

    def forward(self, model, sample, reduce=True):
        net_input_concat = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
            "prev_output_tokens": sample["net_input"]["prev_output_tokens"],
            "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
                "net_input"
            ]["prev_output_tokens"],
            "tgt_speaker": sample["net_input"].get("tgt_speaker", None),
            "return_all_hiddens": True,
        }
        if getattr(model, "asr_task_name", None) is not None:
            net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
                model.asr_task_name
            ]["net_input"]["prev_output_tokens"]

        if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
            net_input_concat = duplicate_input(net_input_concat)

        net_output, extra = model(**net_input_concat)
        loss, nll_loss, rdrop_kl_loss = self.compute_loss(
            model, [net_output], sample, reduce=reduce
        )

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, [net_output], sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        if self.rdrop_alpha > 0:
            logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log

        return loss, sample_size, logging_output


@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(
            task,
            sentence_avg,
            use_guided_attention_loss,
            guided_attention_loss_sigma,
            bce_pos_weight,
            ctc_weight,
        )
        MultitaskCriterion.__init__(self, task.multitask_tasks)

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()
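
        # Note: eos_tgt marks exactly one frame per utterance: positions
        # 0..max_len-1 are compared against feat_len - 1, so the target is
        # 1.0 at the last valid frame and 0.0 elsewhere (e.g. for a length-3
        # utterance padded to max_len=5 the row is [0., 0., 1., 0., 0.]).
        # This supervises the stop-token predictor eos_out.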

        feat_out, eos_out, extra = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            tgt_speaker=sample["net_input"]["tgt_speaker"],
            target_lengths=sample["target_lengths"],
            return_all_hiddens=True,
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            sample["target_lengths"],
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(
                extra["attn"],
                sample["net_input"]["src_lengths"],
                sample["target_lengths"],
                reduction,
            )
        loss = (
            l1_loss + mse_loss + eos_loss + attn_loss
        )  # do not include ctc loss as there's no text target

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
        }

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)

        # inference metrics
        if "targ_frames" in logging_outputs[0]:
            n = sum(log.get("norm_frames", 0) for log in logging_outputs)
            for key, new_key in [
                ("mcd_loss", "mcd_loss"),
                ("pred_frames", "pred_ratio"),
                ("nins", "ins_rate"),
                ("ndel", "del_rate"),
            ]:
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(new_key, val / n, n, round=3)

        if "multitask" not in logging_outputs[0]:
            return

        MultitaskCriterion.reduce_metrics(logging_outputs)


@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogram2passMultitaskTaskCriterion(
    SpeechToSpectrogramMultitaskTaskCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(
            task,
            sentence_avg,
            use_guided_attention_loss,
            guided_attention_loss_sigma,
            bce_pos_weight,
            ctc_weight,
        )

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()

        feat_out, eos_out, extra = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
                "prev_output_tokens"
            ],
            tgt_speaker=sample["net_input"]["tgt_speaker"],
            target_lengths=sample["target_lengths"],
            return_all_hiddens=True,
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            sample["target_lengths"],
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(
                extra["attn"],
                sample["net_input"]["src_lengths"],
                sample["target_lengths"],
                reduction,
            )
        loss = (
            l1_loss + mse_loss + eos_loss + attn_loss
        )  # do not include ctc loss as there's no text target

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
        }

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log
        return loss, sample_size, logging_output
@@ -1,126 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from dataclasses import dataclass, field

import torch.nn.functional as F
from fairseq import metrics
from fairseq.tasks import FairseqTask
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class SpeechUnitLmCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")
    loss_weights: str = field(
        default="1.;0.0;0.0",
        metadata={
            "help": "Weights of the losses that correspond to token, duration, and F0 streams"
        },
    )
    discrete_duration: bool = II("task.discrete_duration")
    discrete_f0: bool = II("task.discrete_f0")


def mae_loss(pred, targ, mask, reduce=True):
    if pred.ndim == 3:
        pred = pred.squeeze(2)
    else:
        assert pred.ndim == 2
    loss = (pred.float() - targ.float()).abs() * (~mask).float()
    loss = loss.sum() if reduce else loss.view(-1)
    return loss


def nll_loss(pred, targ, mask, reduce=True):
    lprob = F.log_softmax(pred, dim=-1)
    loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none")
    loss = loss * (~mask).float().view(-1)
    loss = loss.sum() if reduce else loss.view(-1)
    return loss
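
# Note: both helpers exclude padded positions through the inverted mask
# (~mask), so steps where mask is True contribute nothing to the summed
# loss. E.g. with pred of shape (B, T, V) and targ of shape (B, T):
#
#   total = nll_loss(pred, targ, mask)                   # scalar sum
#   per_step = nll_loss(pred, targ, mask, reduce=False)  # shape (B * T,)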


@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig)
class SpeechUnitLmCriterion(FairseqCriterion):
    def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask):
        super().__init__(task)
        self.sentence_avg = cfg.sentence_avg
        self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")])
        assert self.weights.size(0) == 3
        assert (self.weights >= 0.0).all()

        self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss
        self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])

        token_loss = nll_loss(
            net_output["token"], sample["target"], sample["mask"], reduce
        )
        dur_loss = self.dur_loss_fn(
            net_output["duration"],
            sample["dur_target"],
            sample["dur_mask"],
            reduce,
        )
        f0_loss = self.f0_loss_fn(
            net_output["f0"],
            sample["f0_target"],
            sample["f0_mask"],
            reduce,
        )
        loss = self.weights.to(token_loss.device) * torch.stack(
            [token_loss, dur_loss, f0_loss], dim=-1
        )
        loss = loss.sum() if reduce else loss.sum(-1)

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.detach().sum().item(),
            "token_loss": token_loss.detach().sum().item(),
            "dur_loss": dur_loss.detach().sum().item(),
            "f0_loss": f0_loss.detach().sum().item(),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs)
        dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs)
        f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs)

        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)

        metrics.log_scalar(
            "token_loss", token_loss_sum / sample_size, sample_size, round=3
        )

        metrics.log_scalar("dur_loss", dur_loss_sum / sample_size, sample_size, round=3)

        metrics.log_scalar("f0_loss", f0_loss_sum / sample_size, sample_size, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return True
@@ -1,226 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from omegaconf import II

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import lengths_to_mask
from fairseq.dataclass import FairseqDataclass

logger = logging.getLogger(__name__)


@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
    bce_pos_weight: float = field(
        default=1.0,
        metadata={"help": "weight of positive examples for BCE loss"},
    )
    use_guided_attention_loss: bool = field(
        default=False,
        metadata={"help": "use guided attention loss"},
    )
    guided_attention_loss_sigma: float = field(
        default=0.4,
        metadata={"help": "sigma in guided attention loss"},
    )
    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
    sentence_avg: bool = II("optimization.sentence_avg")


class GuidedAttentionLoss(torch.nn.Module):
    """
    Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
    Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
    """

    def __init__(self, sigma):
        super().__init__()
        self.sigma = sigma

    @staticmethod
    @lru_cache(maxsize=8)
    def _get_weight(s_len, t_len, sigma):
        grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
        grid_x = grid_x.to(s_len.device)
        grid_y = grid_y.to(s_len.device)
        w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
        return 1.0 - torch.exp(-w / (2 * (sigma**2)))
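
    # Note: this is the guided attention weight from the paper cited above:
    # W[t, s] = 1 - exp(-((s / s_len - t / t_len) ** 2) / (2 * sigma ** 2)).
    # It is near 0 close to the diagonal (s / s_len ~ t / t_len) and tends to
    # 1 far from it, so multiplying attention by W penalizes off-diagonal
    # mass and pushes alignments toward monotonicity; a smaller sigma
    # narrows the tolerated band around the diagonal.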

    def _get_weights(self, src_lens, tgt_lens):
        bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
        weights = torch.zeros((bsz, max_t_len, max_s_len))
        for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
            weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
        return weights

    @staticmethod
    def _get_masks(src_lens, tgt_lens):
        in_masks = lengths_to_mask(src_lens)
        out_masks = lengths_to_mask(tgt_lens)
        return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)

    def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
        weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
        masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
        loss = (weights * attn.transpose(1, 2)).masked_select(masks)
        loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
        return loss


@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.bce_pos_weight = bce_pos_weight

        self.guided_attn = None
        if use_guided_attention_loss:
            self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
        self.ctc_weight = ctc_weight

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]

        feat_out, eos_out, extra = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            tgt_lens,
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            net_output = (feat_out, eos_out, extra)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )
        loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    def compute_loss(
        self,
        feat_out,
        feat_out_post,
        eos_out,
        feat_tgt,
        eos_tgt,
        tgt_lens,
        reduction="mean",
    ):
        mask = lengths_to_mask(tgt_lens)
        _eos_out = eos_out[mask].squeeze()
        _eos_tgt = eos_tgt[mask]
        _feat_tgt = feat_tgt[mask]
        _feat_out = feat_out[mask]
        _feat_out_post = feat_out_post[mask]

        l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
            _feat_out_post, _feat_tgt, reduction=reduction
        )
        mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
            _feat_out_post, _feat_tgt, reduction=reduction
        )
        eos_loss = F.binary_cross_entropy_with_logits(
            _eos_out,
            _eos_tgt,
            pos_weight=torch.tensor(self.bce_pos_weight),
            reduction=reduction,
        )
        return l1_loss, mse_loss, eos_loss

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        return False
@@ -1,230 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor


@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
    infonce: bool = field(
        default=False,
        metadata={
            "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
        },
    )
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )


@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
    def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.infonce = infonce
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        logits = model.get_logits(net_output).float()
        target = model.get_targets(sample, net_output)
        self.xla = is_xla_tensor(logits)

        # XXX: handle weights on xla.
        weights = None
        if hasattr(model, "get_target_weights") and not self.infonce:
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()

        losses = []

        reduction = "none" if ((not reduce) or self.xla) else "sum"
        if self.infonce:
            loss = F.cross_entropy(logits, target, reduction=reduction)
        else:
            loss = F.binary_cross_entropy_with_logits(
                logits, target.float(), weights, reduction=reduction
            )

        if self.xla:
            # tpu-comment: since dynamic shapes lead to recompilations on xla,
            # we don't shrink tensors using mask_indices.
            # Instead, we use mask indices to adjust loss.
            mi = (
                sample["net_input"]["mask_indices"]
                .transpose(0, 1)  # logits are transposed in `model.get_logits`
                .reshape(logits.size(0))
            )
            loss = (loss * mi).sum() if reduce else (loss * mi)

        if "sample_size" in sample:
            sample_size = sample["sample_size"]
        elif "mask_indices" in sample["net_input"]:
            sample_size = sample["net_input"]["mask_indices"].sum()
        else:
            sample_size = target.numel() if self.infonce else target.long().sum().item()
        losses.append(loss.detach().clone())

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, coef in zip(extra_losses, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    losses.append(p)

        logging_output = {
            "loss": loss.item() if (reduce and not self.xla) else loss.detach(),
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        for lk in self.log_keys:
            # Only store "logits" and "target" for computing MAP and MAUC
            # during validation
            if lk == "logits":
                if not self.training:
                    logging_output["logits"] = logits.cpu().numpy()
            elif lk == "target":
                if not self.training:
                    # If the targets have been mixed with the predictions of
                    # teacher models, find the original targets
                    if hasattr(model, "get_original_targets"):
                        original_target = model.get_original_targets(sample, net_output)
                    else:
                        original_target = target
                    logging_output["target"] = original_target.cpu().numpy()
            elif lk in net_output:
                value = net_output[lk]
                if not is_xla_tensor(value):
                    value = float(value)
                logging_output[lk] = value

        if len(losses) > 1:
            for i, l in enumerate(losses):
                logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()

        if self.infonce:
            with torch.no_grad():
                if logits.numel() == 0:
                    corr = 0
                    count = 0
                else:
                    assert logits.dim() > 1, logits.shape
                    max = logits.argmax(-1) == 0
                    min = logits.argmin(-1) == 0
                    if is_xla_tensor(logits):
                        max, min = max * mi, min * mi
                        both = max & min
                        corr = max.long().sum() - both.long().sum()
                        count = mi.sum()
                    else:
                        both = max & min
                        corr = max.long().sum().item() - both.long().sum().item()
                        count = float(max.numel())

                logging_output["correct"] = corr
                logging_output["count"] = count

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5
                )
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {
            "loss",
            "ntokens",
            "nsentences",
            "sample_size",
            "correct",
            "count",
        }

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs)
                if k.startswith("loss"):
                    metrics.log_scalar(
                        k, val / (sample_size or 1) / math.log(2), sample_size, round=3
                    )
                else:
                    metrics.log_scalar(k, val / len(logging_outputs), round=3)

    # FIXME: revert when gather based xla reduction is implemented
    # @staticmethod
    # def logging_outputs_can_be_summed() -> bool:
    def logging_outputs_can_be_summed(self) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        # XXX: Gather based reduction not implemented for xla yet.
        # So we fall to sum based reduction for xla.
        return self.xla
@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import criterions, models, modules  # noqa
@@ -1,14 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.model_parallel.criterions." + module)
@@ -1,87 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


try:
    from fairseq.model_parallel.megatron.mpu.cross_entropy import (
        vocab_parallel_cross_entropy,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        target = sample["target"]

        loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        loss = (loss * (target != self.padding_idx)).sum()
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
@@ -1,75 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer

try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
        get_cuda_rng_tracker,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            total_norm = total_norm**2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm**0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
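
    # Note: under model parallelism each rank holds only its shard of the
    # gradients, so the local shard norm is squared, all-reduced (summed)
    # across the model-parallel group, and re-rooted:
    #   ||g|| = sqrt(sum_over_ranks(||g_rank|| ** 2))
    # which gives the true global norm for the optimizer to clip against.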

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
@@ -1,20 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.model_parallel.models." + model_name)
@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .model import *  # noqa
@@ -1,600 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)

EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)


class TransformerEncoderEmbedding(nn.Module):
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))

            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)


class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the end of all encoder layers if
    args.encoder_normalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
|
||||
|
||||
|
||||
class TransformerDecoderEmbedding(nn.Module):
|
||||
"""Decoder Embedding + Positional Embedding"""
|
||||
|
||||
def __init__(self, args, embed_tokens):
|
||||
super().__init__()
|
||||
self.dropout = args.dropout
|
||||
self.share_input_output_embed = args.share_decoder_input_output_embed
|
||||
input_embed_dim = (
|
||||
sum(e.embedding_dim for e in embed_tokens)
|
||||
if isinstance(embed_tokens, nn.ModuleList)
|
||||
else embed_tokens.embedding_dim
|
||||
)
|
||||
embed_dim = args.decoder_embed_dim
|
||||
self.output_embed_dim = args.decoder_output_dim
|
||||
|
||||
padding_idx = (
|
||||
embed_tokens[0].padding_idx
|
||||
if isinstance(embed_tokens, nn.ModuleList)
|
||||
else embed_tokens.padding_idx
|
||||
)
|
||||
self.max_target_positions = args.max_target_positions
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
|
||||
|
||||
self.project_in_dim = (
|
||||
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None

        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions

        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x


class TransformerDecoderOutputLayer(nn.Module):
    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)

                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_tokens)
        else:
            return features
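
# --- Illustration (not part of the original file): with weight tying
# (--share-decoder-input-output-embed), output_layer() above reuses the input
# embedding matrix as the vocabulary projection, i.e. logits = features @ E^T.
# A minimal sketch with hypothetical sizes:
#
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#
#   emb = nn.Embedding(1000, 16)             # vocab=1000, embed_dim=16
#   features = torch.randn(2, 5, 16)         # (batch, tgt_len, embed_dim)
#   logits = F.linear(features, emb.weight)  # (2, 5, 1000), one score per token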


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
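
    # --- Illustration (not part of the original file): maybe_layer_norm()
    # selects between the post-LN order from the paper and the pre-LN
    # (tensor2tensor) order. With normalize_before=False only the after=True
    # call normalizes; with normalize_before=True only the before=True call
    # does. In effect:
    #
    #   # post-LN (normalize_before=False):  x = norm(residual + sublayer(x))
    #   # pre-LN  (normalize_before=True):   x = residual + sublayer(norm(x))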


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]
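
    # --- Illustration (not part of the original file): buffered_future_mask()
    # builds an upper-triangular -inf mask so position t can only attend to
    # positions <= t. For dim=3 the mask added to the attention scores is:
    #
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    #
    # torch.triu(torch.full((3, 3), float("-inf")), 1) produces the same layout.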

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
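
# --- Illustration (not part of the original file): the Embedding/Linear
# helpers above only wrap the torch modules with fairseq's preferred inits
# (normal with std = embedding_dim**-0.5 and a zeroed padding row;
# Xavier-uniform weights with zero bias). A quick check with hypothetical
# sizes:
#
#   emb = Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
#   assert emb.weight[0].abs().sum() == 0   # padding row is all zeros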
@@ -1,789 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
    Embedding,
    TransformerDecoderEmbedding,
    TransformerDecoderLayer,
    TransformerDecoderOutputLayer,
    TransformerEncoderEmbedding,
    TransformerEncoderLayer,
    TransformerEncoderLayerNorm,
)
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    base_architecture,
    transformer_iwslt_de_en,
    transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding


logger = logging.getLogger(__name__)


DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False


def import_pipe():
    global TORCH_PIPE
    global RPC_INIT
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa

        global Pipe
        from torch.distributed.pipeline.sync.utils import partition_model

        global partition_model
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
        # RPC with a single node.
        tmpfile = tempfile.NamedTemporaryFile()
        if not RPC_INIT:
            rpc.init_rpc(
                name="worker",
                rank=0,
                world_size=1,
                rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                    init_method="file://{}".format(tmpfile.name),
                ),
            )
            RPC_INIT = True
        logger.info("Using torch pipe")
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa

            logger.info("Using fairscale pipe")
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
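
# --- Illustration (not part of the original file): how a `balance` list maps
# modules to devices. With balance=[2, 1] a three-module nn.Sequential is
# split into two partitions: the first two modules on one device, the last
# module on the next. A minimal sketch (device ids are hypothetical):
#
#   import torch.nn as nn
#
#   modules = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
#   balance = [2, 1]   # must sum to len(modules)
#   # partition_model(modules, balance, devices=[0, 1]) would place
#   # modules[0:2] on cuda:0 and modules[2:3] on cuda:1 before wrapping
#   # the result in Pipe(..., chunks=N) for micro-batched execution.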


@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        import_pipe()
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        if TORCH_PIPE:
            self.model = Pipe(
                partition_model(nn.Sequential(*module_list), balance, devices),
                chunks=chunks,
                checkpoint=checkpoint,
            )
        else:
            self.model = Pipe(
                nn.Sequential(*module_list),
                balance=balance,
                devices=devices,
                chunks=chunks,
                checkpoint=checkpoint,
            )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        if self.training:
            input_lst = [src_tokens, src_lengths, prev_output_tokens]
            input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
            if TORCH_PIPE:
                return self.model(input).local_value()
            else:
                return self.model(input)
        else:
            assert self.encoder is not None and self.decoder is not None, (
                "encoder and decoder need to be initialized by "
                + "calling the `prepare_for_inference_()` method"
            )
            encoder_output_tuple = self.encoder(src_tokens, src_lengths)
            return self.decoder(encoder_output_tuple)

    def prepare_for_inference_(self, cfg):
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(
            cfg.distributed_training, None, None, encoder_module_list
        )
        self.decoder = TransformerDecoder(
            cfg.distributed_training,
            None,
            None,
            decoder_module_list=decoder_module_list,
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                            help='Number of embedding layer chunks (enables more even distribution '
                                 'of optimizer states across data parallel nodes '
                                 'when using optimizer state sharding and '
                                 'a big embedding vocabulary)')
        # fmt: on

    @classmethod
    def build_model_base(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            assert embed_dim % num_embed_chunks == 0, (
                f"Embedding dimension = {embed_dim} should be "
                + f"divisible by the number of embedding chunks = {num_embed_chunks}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for i in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb
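
        # --- Illustration (not part of the original file): with
        # num_embed_chunks > 1, the token embedding is a ModuleList of smaller
        # embeddings whose outputs are concatenated on the feature dimension,
        # so the effective dimension is still embed_dim. Hypothetical sizes:
        #
        #   chunks = nn.ModuleList([nn.Embedding(100, 8) for _ in range(2)])
        #   tokens = torch.zeros(4, 7, dtype=torch.long)
        #   full = torch.cat([c(tokens) for c in chunks], dim=-1)
        #   assert full.shape == (4, 7, 16)   # 2 chunks of dim 8 -> dim 16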

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoder(args, tgt_dict, embed_tokens)

    @classmethod
    def build_model(cls, args, task):
        encoder, decoder = cls.build_model_base(args, task)
        return PipelineParallelTransformerModel(
            encoder=encoder,
            decoder=decoder,
            balance=utils.eval_str_list(args.pipeline_balance, type=int),
            devices=utils.eval_str_list(args.pipeline_devices, type=int),
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder_max_positions, self.decoder_max_positions)

    def max_positions_helper(
        self, embedding_layer, max_positions_field="max_source_positions"
    ):
        """Maximum input length supported by the encoder or decoder."""
        if embedding_layer.embed_positions is None:
            return getattr(embedding_layer, max_positions_field)
        return min(
            getattr(embedding_layer, max_positions_field),
            embedding_layer.embed_positions.max_positions,
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output, target=target)
            return out.exp_() if not log_probs else out

        # A Pipe() module returns a tuple of tensors as the output.
        # In this case, the tuple has one element - the output tensor of logits
        logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=False)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=False)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder_max_positions

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        is_regular_transformer = not any("model.partitions" in k for k in state_dict)
        if is_regular_transformer:
            state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
        return super().load_state_dict(state_dict, strict)

    def convert_to_pipeline_parallel_state_dict(self, state_dict):
        new_state_dict = self.state_dict()
        encoder_layer_idx = 0
        decoder_layer_idx = 0
        encoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        decoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "encoder_attn.k_proj.weight",
            "encoder_attn.k_proj.bias",
            "encoder_attn.v_proj.weight",
            "encoder_attn.v_proj.bias",
            "encoder_attn.q_proj.weight",
            "encoder_attn.q_proj.bias",
            "encoder_attn.out_proj.weight",
            "encoder_attn.out_proj.bias",
            "encoder_attn_layer_norm.weight",
            "encoder_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        for pid, partition in enumerate(self.model.partitions):
            logger.info(f"Begin Partition {pid}")
            for mid, module in enumerate(partition):
                # fmt: off
                if isinstance(module, TransformerEncoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
                if isinstance(module, TransformerEncoderLayer):
                    for suffix in encoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                    encoder_layer_idx += 1
                if isinstance(module, TransformerDecoderLayer):
                    for suffix in decoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                    decoder_layer_idx += 1
                if isinstance(module, TransformerEncoderLayerNorm):
                    if 'encoder.layer_norm.weight' in state_dict:
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                if isinstance(module, TransformerDecoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
                if isinstance(module, TransformerDecoderOutputLayer):
                    new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                # fmt: on
        return new_state_dict
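
# --- Illustration (not part of the original file): the conversion above turns
# flat transformer checkpoint keys into per-partition keys. For example, with
# a hypothetical layout where partition 0, module 1 holds encoder layer 0:
#
#   old key: 'encoder.layers.0.fc1.weight'
#   new key: 'model.partitions.0.1.fc1.weight'
#
# i.e. the module's position inside the Pipe partitions replaces the
# encoder/decoder layer index in the key name.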


class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = encoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            self.encoder_layers = nn.Sequential(
                *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            )
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
                emb_dim = embed_tokens.embedding_dim
            self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
        else:
            encoder_balance = utils.eval_str_list(
                args.pipeline_encoder_balance, type=int
            )
            encoder_devices = utils.eval_str_list(
                args.pipeline_encoder_devices, type=int
            )
            assert sum(encoder_balance) == len(encoder_module_list), (
                f"Sum of encoder_balance={encoder_balance} is not equal "
                + f"to num_encoder_modules={len(encoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*encoder_module_list),
                        encoder_balance,
                        encoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*encoder_module_list),
                    balance=encoder_balance,
                    devices=encoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            input_tuple(
                src_tokens (LongTensor): tokens in the source language of shape
                    `(batch, src_len)`
                src_lengths (torch.LongTensor): lengths of each source sentence of
                    shape `(batch)`
            )

        Returns:
            output_tuple(
                - **encoder_out** (Tensor): the last encoder layer's output of
                    shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                    padding elements of shape `(batch, src_len)`
                - prev_output_tokens
                - **encoder_states** (List[Tensor]): all intermediate
                    hidden states of shape `(src_len, batch, embed_dim)`.
                    Only populated if *return_all_hiddens* is True.
            )
        """
        dummy_prev_output_tokens = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                encoder_out = self.model(input_tuple).local_value()
            else:
                encoder_out = self.model(input_tuple)
        else:
            encoder_embed_output_tuple = self.embedding_layer(input_tuple)
            encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
            encoder_out = self.final_layer_norm(encoder_layers_output)
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_source_positions
        return min(
            self.embedding_layer.max_source_positions,
            self.embedding_layer.embed_positions.max_positions,
        )


class TransformerDecoder(FairseqDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        decoder_module_list=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = decoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
            self.decoder_layers = nn.Sequential(
                *[
                    TransformerDecoderLayer(args, no_encoder_attn)
                    for _ in range(args.decoder_layers)
                ]
            )
            self.decoder_output_layer = TransformerDecoderOutputLayer(
                args, embed_tokens, dictionary
            )
        else:
            decoder_balance = utils.eval_str_list(
                args.pipeline_decoder_balance, type=int
            )
            decoder_devices = utils.eval_str_list(
                args.pipeline_decoder_devices, type=int
            )
            assert sum(decoder_balance) == len(decoder_module_list), (
                f"Sum of decoder_balance={decoder_balance} is not equal "
                + f"to num_decoder_modules={len(decoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*decoder_module_list),
                        decoder_balance,
                        decoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*decoder_module_list),
                    balance=decoder_balance,
                    devices=decoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        input_tuple = (
            encoder_out.encoder_out,
            encoder_out.encoder_padding_mask,
            prev_output_tokens,
        )
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                return (self.model(input_tuple).local_value(),)
            else:
                return (self.model(input_tuple),)
        else:
            embed_layer_output = self.embedding_layer(input_tuple)
            state = self.decoder_layers(embed_layer_output)
            return (self.decoder_output_layer(state),)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_out)
        else:
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_target_positions
        return min(
            self.embedding_layer.max_target_positions,
            self.embedding_layer.embed_positions.max_positions,
        )

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)

        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    transformer_iwslt_de_en(args)


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    transformer_wmt_en_de_big(args)
@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .model import *  # noqa
@@ -1,225 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
    roberta_base_architecture,
    roberta_prenorm_architecture,
    RobertaEncoder,
    RobertaModel,
)
from fairseq.modules import LayerNorm


try:
    from fairseq.model_parallel.megatron.mpu import (
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
        ColumnParallelLinear,
        VocabParallelEmbedding,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False

logger = logging.getLogger(__name__)


@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)
        parser.add_argument(
            "--no-final-layer-norm",
            action="store_true",
            help=(
                "don't add final layernorm (only applicable when "
                "--encoder-normalize-before=True)"
            ),
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )


class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        x = gather_from_model_parallel_region(x).contiguous()
        x = x + self.bias
        return x


class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
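
# --- Illustration (not part of the original file): the head above pools the
# first time step (<s>) and maps it to class logits. Hypothetical shapes:
#
#   features = torch.randn(3, 12, 768)   # (batch, seq_len, embed_dim)
#   pooled = features[:, 0, :]           # (3, 768) - one vector per sentence
#   # dense: 768 -> inner_dim, out_proj: inner_dim -> num_classes,
#   # so the head returns logits of shape (3, num_classes).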
|
||||
|
||||
|
||||
class ModelParallelRobertaEncoder(RobertaEncoder):
|
||||
"""RoBERTa encoder."""
|
||||
|
||||
def __init__(self, args, dictionary):
|
||||
super().__init__(args, dictionary)
|
||||
assert not self.args.untie_weights_roberta
|
||||
|
||||
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
|
||||
return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)
|
||||
|
||||
def build_encoder(self, args, dictionary, embed_tokens):
|
||||
return ModelParallelTransformerEncoder(args, dictionary, embed_tokens)
|
||||
|
||||
def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
|
||||
return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight)
|
||||
|
||||
|
||||
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
|
||||
def base_architecture(args):
|
||||
args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
|
||||
# model parallel RoBERTa defaults to "Pre-LN" formulation
|
||||
roberta_prenorm_architecture(args)
|
||||
|
||||
|
||||
# earlier versions of model parallel RoBERTa removed the final layer norm
|
||||
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1")
|
||||
def model_parallel_roberta_v1_architecture(args):
|
||||
args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture(
|
||||
"model_parallel_roberta", "model_parallel_roberta_postnorm"
|
||||
)
|
||||
def model_parallel_roberta_postnorm_architecture(args):
|
||||
# the original BERT/RoBERTa uses the "Post-LN" formulation
|
||||
roberta_base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
|
||||
def model_parallel_roberta_base_architecture(args):
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
|
||||
def model_parallel_roberta_large_architecture(args):
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
||||
base_architecture(args)
|
||||
@@ -1,121 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn

from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)

try:
    from fairseq.model_parallel.megatron.mpu import (
        VocabParallelEmbedding,
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings**-0.5)
            nn.init.constant_(tensor[1], 0)

        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )

        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)

        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
@@ -1,169 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel

try:
    from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
    @staticmethod
    def add_args(parser):
        TransformerLanguageModel.add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        if args.character_embeddings:
            raise NotImplementedError(
                "Character embeddings is not supported for model parallel"
            )
        elif args.adaptive_input:
            raise NotImplementedError(
                "Adaptive input is not supported for model parallel"
            )
        else:
            embed_tokens = cls.build_embedding(
                args, task.source_dictionary, args.decoder_input_dim
            )

        decoder = ModelParallelTransformerDecoder(
            args,
            task.target_dictionary,
            embed_tokens,
            no_encoder_attn=True,
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=embed_dim**-0.5)
            nn.init.constant_(tensor[1], 0)

        embed_tokens = VocabParallelEmbedding(
            len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
        )
        return embed_tokens


def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)
    args.character_filters = getattr(
        args,
        "character_filters",
        "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
    )
    args.character_embedding_dim = getattr(args, "character_embedding_dim", 4)
    args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0)
    args.add_bos_token = getattr(args, "add_bos_token", False)


@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture(
    "model_parallel_transformer_lm", "transformer_lm_megatron_11b"
)
def transformer_lm_megatron_11b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)
@@ -1,17 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
    ModelParallelTransformerEncoderLayer,
    ModelParallelTransformerDecoderLayer,
)

__all__ = [
    "ModelParallelMultiheadAttention",
    "ModelParallelTransformerEncoderLayer",
    "ModelParallelTransformerDecoderLayer",
]
@@ -1,349 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout

try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
        get_cuda_rng_tracker,
        get_model_parallel_world_size,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@with_incremental_state
class ModelParallelMultiheadAttention(nn.Module):
    """Model parallel Multi-headed attention.
    This performs the Multi-headed attention over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (
            self.num_heads_partition * self.model_parallel_size == num_heads
        ), "Number of heads must be divisible by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert (
            not self.self_attention or self.qkv_same_dim
        ), "Self-attention requires query, key and value to be of the same size"

        self.k_proj = ColumnParallelLinear(
            self.kdim, embed_dim, bias=bias, gather_output=False
        )
        self.v_proj = ColumnParallelLinear(
            self.vdim, embed_dim, bias=bias, gather_output=False
        )
        self.q_proj = ColumnParallelLinear(
            embed_dim, embed_dim, bias=bias, gather_output=False
        )
        self.out_proj = RowParallelLinear(
            embed_dim, embed_dim, bias=bias, input_is_parallel=True
        )

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        **unused_kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        is_tpu = query.device.type == "xla"

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = (
                ModelParallelMultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=bsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )
            )

            saved_state["prev_key"] = k.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_value"] = v.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_weights.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(
                bsz, self.num_heads_partition, tgt_len, src_len
            )
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(
                bsz * self.num_heads_partition, tgt_len, src_len
            )

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        with get_cuda_rng_tracker().fork():
            attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            self.head_dim,
        ]
        embed_dim_partition = embed_dim // self.model_parallel_size
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
        attn = self.out_proj(attn)
        # return attn_weights None to keep the return type same as single gpu multihead attention
        # This will be deprecated.
        attn_weights: Optional[Tensor] = None

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:

            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
            if prev_key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
            if key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                if input_buffer[k] is not None:
                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
@@ -1,78 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer


try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )


class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=not getattr(args, "cross_self_attention", False),
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
@@ -1,55 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib
import os
from abc import ABC, abstractmethod

from fairseq import registry
from omegaconf import DictConfig


class BaseScorer(ABC):
    def __init__(self, cfg):
        self.cfg = cfg
        self.ref = []
        self.pred = []

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    @abstractmethod
    def score(self) -> float:
        pass

    @abstractmethod
    def result_string(self) -> str:
        pass


_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
    "--scoring", default="bleu"
)


def build_scorer(choice, tgt_dict):
    _choice = choice._name if isinstance(choice, DictConfig) else choice

    if _choice == "bleu":
        from fairseq.scoring import bleu

        return bleu.Scorer(
            bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
        )
    return _build_scorer(choice)


# automatically import any Python files in the current directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.scoring." + module)
@@ -1,44 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import numpy as np

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class BertScoreScorerConfig(FairseqDataclass):
    bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"})


@register_scorer("bert_score", dataclass=BertScoreScorerConfig)
class BertScoreScorer(BaseScorer):
    def __init__(self, cfg):
        super(BertScoreScorer, self).__init__(cfg)
        try:
            import bert_score as _bert_score
        except ImportError:
            raise ImportError("Please install BERTScore: pip install bert-score")

        self.cfg = cfg
        self._bert_score = _bert_score
        self.scores = None

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        _, _, self.scores = self._bert_score.score(
            self.pred, self.ref, lang=self.cfg.bert_score_lang
        )
        self.scores = self.scores.numpy()
        return np.mean(self.scores)

    def result_string(self, order=4):
        return f"BERTScore: {self.score():.4f}"
@@ -1,168 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import math
import sys
from dataclasses import dataclass, field

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


class BleuStat(ctypes.Structure):
    _fields_ = [
        ("reflen", ctypes.c_size_t),
        ("predlen", ctypes.c_size_t),
        ("match1", ctypes.c_size_t),
        ("count1", ctypes.c_size_t),
        ("match2", ctypes.c_size_t),
        ("count2", ctypes.c_size_t),
        ("match3", ctypes.c_size_t),
        ("count3", ctypes.c_size_t),
        ("match4", ctypes.c_size_t),
        ("count4", ctypes.c_size_t),
    ]


@dataclass
class SacrebleuConfig(FairseqDataclass):
    sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="13a", metadata={"help": "tokenizer"}
    )
    sacrebleu_lowercase: bool = field(
        default=False, metadata={"help": "apply lowercasing"}
    )
    sacrebleu_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )


@register_scorer("sacrebleu", dataclass=SacrebleuConfig)
class SacrebleuScorer(BaseScorer):
    def __init__(self, cfg):
        super(SacrebleuScorer, self).__init__(cfg)
        import sacrebleu

        self.sacrebleu = sacrebleu
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=cfg.sacrebleu_tokenizer,
            lowercase=cfg.sacrebleu_lowercase,
            character_tokenization=cfg.sacrebleu_char_level,
        )

    def add_string(self, ref, pred):
        self.ref.append(self.tokenizer.tokenize(ref))
        self.pred.append(self.tokenizer.tokenize(pred))

    def _score(self, order=4):
        if order != 4:
            raise NotImplementedError
        # tokenization and lowercasing are performed by self.tokenizer instead.
        return self.sacrebleu.corpus_bleu(self.pred, [self.ref], tokenize="none")

    def score(self, order=4):
        return self._score(order).score

    def result_string(self, order=4):
        return self._score(order).format()


@dataclass
class BleuConfig(FairseqDataclass):
    pad: int = field(default=1, metadata={"help": "padding index"})
    eos: int = field(default=2, metadata={"help": "eos index"})
    unk: int = field(default=3, metadata={"help": "unk index"})


@register_scorer("bleu", dataclass=BleuConfig)
class Scorer(object):
    def __init__(self, cfg):
        self.stat = BleuStat()
        self.pad = cfg.pad
        self.eos = cfg.eos
        self.unk = cfg.unk

        try:
            from fairseq import libbleu
        except ImportError as e:
            sys.stderr.write(
                "ERROR: missing libbleu.so. run `pip install --editable .`\n"
            )
            raise e

        self.C = ctypes.cdll.LoadLibrary(libbleu.__file__)

        self.reset()

    def reset(self, one_init=False):
        if one_init:
            self.C.bleu_one_init(ctypes.byref(self.stat))
        else:
            self.C.bleu_zero_init(ctypes.byref(self.stat))

    def add(self, ref, pred):
        if not isinstance(ref, torch.IntTensor):
            raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref)))
        if not isinstance(pred, torch.IntTensor):
            raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred)))

        # don't match unknown words
        rref = ref.clone()
        assert not rref.lt(0).any()
        rref[rref.eq(self.unk)] = -999

        rref = rref.contiguous().view(-1)
        pred = pred.contiguous().view(-1)

        self.C.bleu_add(
            ctypes.byref(self.stat),
            ctypes.c_size_t(rref.size(0)),
            ctypes.c_void_p(rref.data_ptr()),
            ctypes.c_size_t(pred.size(0)),
            ctypes.c_void_p(pred.data_ptr()),
            ctypes.c_int(self.pad),
            ctypes.c_int(self.eos),
        )

    def score(self, order=4):
        psum = sum(
            math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order]
        )
        return self.brevity() * math.exp(psum / order) * 100

    def precision(self):
        def ratio(a, b):
            return a / b if b > 0 else 0

        return [
            ratio(self.stat.match1, self.stat.count1),
            ratio(self.stat.match2, self.stat.count2),
            ratio(self.stat.match3, self.stat.count3),
            ratio(self.stat.match4, self.stat.count4),
        ]

    def brevity(self):
        r = self.stat.reflen / self.stat.predlen
        return min(1, math.exp(1 - r))

    def result_string(self, order=4):
        assert order <= 4, "BLEU scores for order > 4 aren't supported"
        fmt = "BLEU{} = {:2.2f}, {:2.1f}"
        for _ in range(1, order):
            fmt += "/{:2.1f}"
        fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})"
        bleup = [p * 100 for p in self.precision()[:order]]
        return fmt.format(
            order,
            self.score(order=order),
            *bleup,
            self.brevity(),
            self.stat.predlen / self.stat.reflen,
            self.stat.predlen,
            self.stat.reflen
        )
@@ -1,36 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class ChrFScorerConfig(FairseqDataclass):
    pass


@register_scorer("chrf", dataclass=ChrFScorerConfig)
class ChrFScorer(BaseScorer):
    def __init__(self, args):
        super(ChrFScorer, self).__init__(args)
        import sacrebleu

        self.sacrebleu = sacrebleu

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        return self.result_string(order).score

    def result_string(self, order=4):
        if order != 4:
            raise NotImplementedError
        return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format()
@@ -1,42 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class MeteorScorerConfig(FairseqDataclass):
    pass


@register_scorer("meteor", dataclass=MeteorScorerConfig)
class MeteorScorer(BaseScorer):
    def __init__(self, args):
        super(MeteorScorer, self).__init__(args)
        try:
            import nltk
        except ImportError:
            raise ImportError("Please install nltk to use METEOR scorer")

        self.nltk = nltk
        self.scores = []

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        self.scores = [
            self.nltk.translate.meteor_score.single_meteor_score(r, p)
            for r, p in zip(self.ref, self.pred)
        ]
        return np.mean(self.scores)

    def result_string(self, order=4):
        return f"METEOR: {self.score():.4f}"
@@ -1,80 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unicodedata

import sacrebleu as sb

from fairseq.dataclass import ChoiceEnum

SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2


class EvaluationTokenizer(object):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers
    in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing, punctuation removal and character tokenization, which are
    applied after sacreBLEU tokenization.

    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)
    _ALL_TOKENIZER_TYPES = (
        sb.BLEU.TOKENIZERS
        if SACREBLEU_V2_ABOVE
        else ["none", "13a", "intl", "zh", "ja-mecab"]
    )
    ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES)

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):

        assert (
            tokenizer_type in self._ALL_TOKENIZER_TYPES
        ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}"
        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization
        if SACREBLEU_V2_ABOVE:
            self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer
        else:
            self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]()

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Remove punctuation based on Unicode category."""
        return cls.SPACE.join(
            t
            for t in sent.split(cls.SPACE)
            if not all(unicodedata.category(c)[0] == "P" for c in t)
        )

    def tokenize(self, sent: str):
        tokenized = self.tokenizer(sent)

        if self.punctuation_removal:
            tokenized = self.remove_punctuation(tokenized)

        if self.character_tokenization:
            tokenized = self.SPACE.join(
                list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE))
            )

        if self.lowercase:
            tokenized = tokenized.lower()

        return tokenized
@@ -1,58 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


@dataclass
class WerScorerConfig(FairseqDataclass):
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
    )
    wer_remove_punct: bool = field(
        default=False, metadata={"help": "remove punctuation"}
    )
    wer_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )
    wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})


@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.reset()
        try:
            import editdistance as ed
        except ImportError:
            raise ImportError("Please install editdistance to use WER scorer")
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        return f"WER: {self.score():.2f}"

    def score(self):
        return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0
@@ -20,8 +20,8 @@ import os
import logging

DEBUG_PREFIX = "<RVC module>"
INPUT_FILE_PATH = "rvc_input.wav" #"./data/tmp/rvc_input.wav"
OUTPUT_FILE_PATH = "rvc_output.wav" #"./data/tmp/rvc_output.wav"
INPUT_FILE_PATH = "data/tmp/rvc_input.wav" #"./data/tmp/rvc_input.wav"
OUTPUT_FILE_PATH = "data/tmp/rvc_output.wav" #"./data/tmp/rvc_output.wav"

def rvc_process_audio():
    """