Continued monkey patching fairseq for RVC, Python 3.11

Tony Ribeiro
2023-08-10 16:15:54 +02:00
parent 240684c8c4
commit 3fbb1b6ede
78 changed files with 15 additions and 9090 deletions

View File

@@ -26,20 +26,20 @@ sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar
# initialize hydra
from fairseq.dataclass.initialize import hydra_init
#from fairseq.dataclass.initialize import hydra_init
#hydra_init()
import fairseq.criterions # noqa
import fairseq.distributed # noqa
import fairseq.models # noqa
import fairseq.modules # noqa
import fairseq.optim # noqa
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.scoring # noqa
import fairseq.tasks # noqa
import fairseq.token_generation_constraints # noqa
#import fairseq.criterions # noqa
#import fairseq.distributed # noqa
#import fairseq.models # noqa
#import fairseq.modules # noqa
#import fairseq.optim # noqa
#import fairseq.optim.lr_scheduler # noqa
#import fairseq.pdb # noqa
#import fairseq.scoring # noqa
#import fairseq.tasks # noqa
#import fairseq.token_generation_constraints # noqa
import fairseq.benchmark # noqa
import fairseq.model_parallel # noqa
#import fairseq.benchmark # noqa
#import fairseq.model_parallel # noqa
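
The hunk above relies on pre-registering stub modules in sys.modules before fairseq itself is imported, so that imports which break under Python 3.11 resolve to the stubs instead of executing the original submodules. A minimal, self-contained sketch of that pattern (the stub contents are illustrative, not part of this commit):

import sys
import types
import importlib

# hypothetical stub standing in for a submodule that fails to import on
# Python 3.11; a real stub would expose the attributes callers rely on
metrics = types.ModuleType("fairseq.metrics")
sys.modules["fairseq.metrics"] = metrics

# later imports (including fairseq's own internal ones) now receive the
# stub instead of executing the broken module
assert importlib.import_module("fairseq.metrics") is metrics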

View File

@@ -1,7 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# import models/tasks to register them
from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa

View File

@@ -1,172 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from fairseq.modules.multihead_attention import MultiheadAttention
BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
def _reset_seeds():
torch.manual_seed(0)
random.seed(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def benchmark_multihead_attention(
label="",
attn_dtype=torch.uint8,
key_padding_dtype=torch.uint8,
add_bias_kv=False,
add_zero_attn=False,
static_kv=False,
batch_size=20,
embedding=EMB,
seq_len=SEQ,
num_heads=HEADS,
):
results = []
# device = torch.device("cuda")
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
key_padding_mask = _get_mask(
to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
)
q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
fns = [
original_bench_fw,
xformers_bench_fw,
original_bench_fw_bw,
xformers_bench_fw_bw,
]
for fn in fns:
results.append(
benchmark.Timer(
stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
globals={
"q": q,
"k": k,
"v": v,
"key_padding_mask": key_padding_mask,
"attn_mask": attn_mask,
"static_kv": static_kv,
"fn": fn,
},
label="multihead fw + bw",
sub_label=f"{fn.__name__}",
description=label,
).blocked_autorange(min_run_time=1)
)
compare = benchmark.Compare(results)
compare.print()
def run_benchmarks():
for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
):
label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
benchmark_multihead_attention(
label=label,
attn_dtype=attn_dtype,
key_padding_dtype=key_padding_dtype,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
run_benchmarks()

View File

@@ -1,36 +0,0 @@
import numpy as np
from fairseq.data import FairseqDataset
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size
def __getitem__(self, index):
return index
def __len__(self):
return self.num_items
def collater(self, samples):
return self.batch
@property
def sizes(self):
return np.array([self.item_size] * self.num_items)
def num_tokens(self, index):
return self.item_size
def size(self, index):
return self.item_size
def ordered_indices(self):
return np.arange(self.num_items)
@property
def supports_prefetch(self):
return False

View File

@@ -1,83 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass
class DummyLMConfig(FairseqDataclass):
dict_size: int = 49996
dataset_size: int = 100000
tokens_per_sample: int = field(
default=512, metadata={"help": "max sequence length"}
)
add_bos_token: bool = False
batch_size: Optional[int] = II("dataset.batch_size")
max_tokens: Optional[int] = II("dataset.max_tokens")
max_target_positions: int = II("task.tokens_per_sample")
@register_task("dummy_lm", dataclass=DummyLMConfig)
class DummyLMTask(FairseqTask):
def __init__(self, cfg: DummyLMConfig):
super().__init__(cfg)
# load dictionary
self.dictionary = Dictionary()
for i in range(cfg.dict_size):
self.dictionary.add_symbol("word{}".format(i))
self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8
logger.info("dictionary: {} types".format(len(self.dictionary)))
seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1
self.dummy_src = seq[:-1]
self.dummy_tgt = seq[1:]
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.cfg.batch_size is not None:
bsz = self.cfg.batch_size
else:
bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.cfg.tokens_per_sample, dtype=torch.long
),
},
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.cfg.tokens_per_sample,
},
num_items=self.cfg.dataset_size,
item_size=self.cfg.tokens_per_sample,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -1,94 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from omegaconf import II
from .dummy_dataset import DummyDataset
from fairseq.data import Dictionary
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
logger = logging.getLogger(__name__)
@dataclass
class DummyMaskedLMConfig(FairseqDataclass):
dict_size: int = 49996
dataset_size: int = 100000
tokens_per_sample: int = field(
default=512,
metadata={
"help": "max number of total tokens over all"
" segments per sample for BERT dataset"
},
)
batch_size: Optional[int] = II("dataset.batch_size")
max_tokens: Optional[int] = II("dataset.max_tokens")
max_target_positions: int = II("task.tokens_per_sample")
@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
class DummyMaskedLMTask(FairseqTask):
def __init__(self, cfg: DummyMaskedLMConfig):
super().__init__(cfg)
self.dictionary = Dictionary()
for i in range(cfg.dict_size):
self.dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(self.dictionary)))
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8
mask_idx = 0
pad_idx = 1
seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
mask = torch.arange(2, cfg.tokens_per_sample, 7) # ~15%
src = seq.clone()
src[mask] = mask_idx
tgt = torch.full_like(seq, pad_idx)
tgt[mask] = seq[mask]
self.dummy_src = src
self.dummy_tgt = tgt
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if self.cfg.batch_size is not None:
bsz = self.cfg.batch_size
else:
bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.cfg.tokens_per_sample, dtype=torch.long
),
},
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
"nsentences": bsz,
"ntokens": bsz * self.cfg.tokens_per_sample,
},
num_items=self.cfg.dataset_size,
item_size=self.cfg.tokens_per_sample,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -1,96 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from fairseq.models import (
FairseqDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
@staticmethod
def add_args(parser):
parser.add_argument("--num-layers", type=int, default=24)
parser.add_argument("--embed-dim", type=int, default=1024)
@classmethod
def build_model(cls, args, task):
encoder = DummyEncoder(
num_embed=len(task.target_dictionary),
embed_dim=args.embed_dim,
num_layers=args.num_layers,
)
return cls(args, encoder)
def forward(self, src_tokens, masked_tokens=None, **kwargs):
return self.decoder(src_tokens, masked_tokens=masked_tokens)
class DummyEncoder(FairseqDecoder):
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
super().__init__(Dictionary())
self.embed = nn.Embedding(
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
)
self.layers_a = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
nn.Linear(embed_dim, embed_dim), # output projection
nn.Dropout(),
)
for i in range(num_layers)
]
)
self.layers_b = nn.ModuleList(
[
nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, 4 * embed_dim), # FFN
nn.ReLU(),
nn.Linear(4 * embed_dim, embed_dim), # FFN
nn.Dropout(0.1),
)
for i in range(num_layers)
]
)
self.out_proj = nn.Linear(embed_dim, num_embed)
def forward(self, tokens, masked_tokens=None):
x = self.embed(tokens)
for layer_a, layer_b in zip(self.layers_a, self.layers_b):
x = x + layer_a(x)
x = x + layer_b(x)
x = self.out_proj(x)
if masked_tokens is not None:
x = x[masked_tokens]
return (x,)
def max_positions(self):
return 1024
def get_normalized_probs(self, net_output, log_probs, sample=None):
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
pass

View File

@@ -1,119 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("--dict-size", default=49996, type=int)
parser.add_argument("--dataset-size", default=100000, type=int)
parser.add_argument("--src-len", default=30, type=int)
parser.add_argument("--tgt-len", default=30, type=int)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol("word{}".format(i))
logger.info("dictionary: {} types".format(len(dictionary)))
args.max_source_positions = args.src_len + dictionary.pad() + 2
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
return cls(args, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
item_size = max(self.args.src_len, self.args.tgt_len)
if self.args.batch_size is not None:
bsz = self.args.batch_size
else:
bsz = max(1, self.args.max_tokens // item_size)
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
self.datasets[split] = DummyDataset(
{
"id": 1,
"net_input": {
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
"src_lengths": torch.full(
(bsz,), self.args.src_len, dtype=torch.long
),
"prev_output_tokens": tgt.clone(),
},
"target": tgt,
"nsentences": bsz,
"ntokens": bsz * self.args.tgt_len,
},
num_items=self.args.dataset_size,
item_size=item_size,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
class DummyDataset(FairseqDataset):
def __init__(self, batch, num_items, item_size):
super().__init__()
self.batch = batch
self.num_items = num_items
self.item_size = item_size
def __getitem__(self, index):
return index
def __len__(self):
return self.num_items
def collater(self, samples):
return self.batch
@property
def sizes(self):
return np.array([self.item_size] * self.num_items)
def num_tokens(self, index):
return self.item_size
def size(self, index):
return self.item_size
def ordered_indices(self):
return np.arange(self.num_items)
@property
def supports_prefetch(self):
return False

View File

@@ -1,55 +0,0 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
#include <torch/extension.h>
#include <vector>
/*
CPP Binding for CUDA OP
*/
// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(
torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size) {
CHECK_INPUT(tokens);
CHECK_INPUT(lprobs);
assert(bsz > 0);
assert(step >= 0);
assert(beam_size > 0);
assert(no_repeat_ngram_size > 0);
return ngram_repeat_block_cuda_forward(
tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&ngram_repeat_block_forward,
"No Repeat Ngram Block forward (CUDA)");
}

View File

@@ -1,82 +0,0 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for blocking repeated n-grams.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>
// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(
long* __restrict__ tokens,
float* __restrict__ lprobs,
int max_predict_len,
int vocab_size,
int no_repeat_ngram_size) {
auto row = blockIdx.x;
auto col = threadIdx.x;
auto start = row * (max_predict_len) + col;
// Each thread compares ngram starting from
// thread index with final ngram starting from
// step - no_repeat_ngram_size + 2
auto check_start_pos = blockDim.x;
auto lprob_start = row * vocab_size;
bool is_banned = true;
extern __shared__ long tokens_shm[];
tokens_shm[col] = tokens[start];
if (col == blockDim.x - 1) {
for (int i = 1; i < no_repeat_ngram_size; i++) {
if (col + i < max_predict_len) {
tokens_shm[col + i] = tokens[start + i];
}
}
}
__syncthreads();
for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
is_banned = false;
}
}
if (is_banned == true) {
auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}
}
// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
const torch::Tensor tokens,
torch::Tensor lprobs,
int bsz,
int step,
int beam_size,
int no_repeat_ngram_size) {
int threads = step - no_repeat_ngram_size + 2;
if (threads <= 0)
return lprobs;
int max_predict_len = tokens.size(1);
int vocab_size = lprobs.size(1);
auto token_ptr = tokens.data_ptr<long>();
auto lprob_ptr = lprobs.data_ptr<float>();
int blocks = bsz * beam_size;
int shared_mem_size = (step + 1) * sizeof(long);
// Launching N blocks where N is number of samples in a batch (beams*bsz)
// Launching T threads where T is number of previous ngrams in a sample
// Allocating shared mem per block for faster access of input tokens since
// each token will be accessed N times to compare with current Ngram where
// N is Ngram size.
banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
return lprobs;
}
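
For context, an extension like the two files above is typically JIT-compiled and invoked from Python via torch.utils.cpp_extension. A hedged sketch, not part of this diff (the source file names and tensor shapes are assumptions):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the op from the two sources shown above (file names assumed)
ngram_repeat_block = load(
    name="ngram_repeat_block",
    sources=["ngram_repeat_block_op.cpp", "ngram_repeat_block_cuda.cu"],
)

bsz, beam_size, step, ngram_size, vocab = 4, 2, 10, 3, 100
tokens = torch.randint(
    0, vocab, (bsz * beam_size, step + 1), dtype=torch.long, device="cuda"
)
lprobs = torch.randn(bsz * beam_size, vocab, device="cuda")
# sets lprobs of continuations that would repeat an n-gram to -inf
lprobs = ngram_repeat_block.forward(
    tokens, lprobs, bsz, step, beam_size, ngram_size
)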

View File

@@ -1,109 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the
implementation from https://github.com/bkj/auction-lap. Adapted to be more
efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using namespace torch::indexing;
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
int max_iterations = 100;
torch::Tensor epsilon =
(job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
epsilon.clamp_min_(1e-04);
torch::Tensor worker_and_job_to_score =
job_and_worker_to_score.detach().transpose(0, 1).contiguous();
int num_workers = worker_and_job_to_score.size(0);
int num_jobs = worker_and_job_to_score.size(1);
auto device = worker_and_job_to_score.device();
int jobs_per_worker = num_jobs / num_workers;
torch::Tensor value = worker_and_job_to_score.clone();
int counter = 0;
torch::Tensor max_value = worker_and_job_to_score.max();
torch::Tensor bid_indices;
torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
torch::Tensor bids =
worker_and_job_to_score.new_empty({num_workers, num_jobs});
torch::Tensor bid_increments =
worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
torch::Tensor top_values =
worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});
torch::Tensor top_index = top_values.to(torch::kLong);
torch::Tensor high_bidders = top_index.new_empty({num_jobs});
torch::Tensor have_bids = high_bidders.to(torch::kBool);
torch::Tensor jobs_indices =
torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
torch::Tensor true_tensor =
torch::ones({1}, torch::dtype(torch::kBool).device(device));
while (true) {
bids.zero_();
torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);
// Each worker bids the difference in value between that job and the k+1th
// job
torch::sub_out(
bid_increments,
top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
bid_increments.add_(epsilon);
bids.scatter_(
1,
top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
bid_increments);
if (counter < max_iterations && counter > 0) {
// Put in a minimal bid to retain items from the last round if no-one else
// bids for them this round
bids.view(-1).index_put_({bid_indices}, epsilon);
}
// Find the highest bidding worker per job
torch::max_out(high_bids, high_bidders, bids, 0);
torch::gt_out(have_bids, high_bids, 0);
if (have_bids.all().item<bool>()) {
// All jobs were bid for
break;
}
// Make popular items more expensive
cost.add_(high_bids);
torch::sub_out(value, worker_and_job_to_score, cost);
bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});
if (counter < max_iterations) {
// Make sure that this item will be in the winning worker's top-k next
// time.
value.view(-1).index_put_({bid_indices}, max_value);
} else {
// Suboptimal approximation that converges quickly from current solution
value.view(-1).index_put_(
{bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
}
counter += 1;
}
return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
.reshape(-1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
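
A hedged usage sketch for the auction op above (the module name and build step are assumptions, not part of this diff):

import torch
from torch.utils.cpp_extension import load

# JIT-compile the op above (assumed to be built under the name `libbase`)
libbase = load(name="libbase", sources=["balanced_assignment.cpp"])

num_jobs, num_workers = 8, 4   # num_jobs must be a multiple of num_workers
scores = torch.randn(num_jobs, num_workers)  # job-by-worker score matrix
jobs = libbase.balanced_assignment(scores)   # flat tensor of job indices
print(jobs.view(num_workers, -1))            # jobs_per_worker jobs per worker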

View File

@@ -1,157 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <array>
#include <cstdio>
#include <cstring>
#include <map>
// NOLINTNEXTLINE
typedef struct {
size_t reflen;
size_t predlen;
size_t match1;
size_t count1;
size_t match2;
size_t count2;
size_t match3;
size_t count3;
size_t match4;
size_t count4;
} bleu_stat;
// left trim (remove pad)
void bleu_ltrim(size_t* len, int** sent, int pad) {
size_t start = 0;
while (start < *len) {
if (*(*sent + start) != pad) {
break;
}
start++;
}
*sent += start;
*len -= start;
}
// right trim (remove eos)
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
size_t end = *len - 1;
while (end > 0) {
if (*(*sent + end) != eos && *(*sent + end) != pad) {
break;
}
end--;
}
*len = end + 1;
}
// left and right trim
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
bleu_ltrim(len, sent, pad);
bleu_rtrim(len, sent, pad, eos);
}
size_t bleu_hash(int len, int* data) {
size_t h = 14695981039346656037ul;
size_t prime = 0x100000001b3;
char* b = (char*)data;
size_t blen = sizeof(int) * len;
while (blen-- > 0) {
h ^= *b++;
h *= prime;
}
return h;
}
void bleu_addngram(
size_t* ntotal,
size_t* nmatch,
size_t n,
size_t reflen,
int* ref,
size_t predlen,
int* pred) {
if (predlen < n) {
return;
}
predlen = predlen - n + 1;
(*ntotal) += predlen;
if (reflen < n) {
return;
}
reflen = reflen - n + 1;
std::map<size_t, size_t> count;
while (predlen > 0) {
size_t w = bleu_hash(n, pred++);
count[w]++;
predlen--;
}
while (reflen > 0) {
size_t w = bleu_hash(n, ref++);
if (count[w] > 0) {
(*nmatch)++;
count[w] -= 1;
}
reflen--;
}
}
extern "C" {
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_zero_init(bleu_stat* stat) {
std::memset(stat, 0, sizeof(bleu_stat));
}
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_one_init(bleu_stat* stat) {
bleu_zero_init(stat);
stat->count1 = 0;
stat->count2 = 1;
stat->count3 = 1;
stat->count4 = 1;
stat->match1 = 0;
stat->match2 = 1;
stat->match3 = 1;
stat->match4 = 1;
}
#ifdef _WIN64
__declspec(dllexport)
#endif
void bleu_add(
bleu_stat* stat,
size_t reflen,
int* ref,
size_t predlen,
int* pred,
int pad,
int eos) {
bleu_trim(&reflen, &ref, pad, eos);
bleu_trim(&predlen, &pred, pad, eos);
stat->reflen += reflen;
stat->predlen += predlen;
bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}
}
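
The struct above only accumulates n-gram counts; the score itself is computed on the Python side. A hedged sketch of that last step (standard corpus BLEU with brevity penalty), assuming the struct is exposed to Python, e.g. via ctypes, with the field names above:

import math

def bleu_score(stat, order=4):
    # modified n-gram precisions from the accumulated match/count fields
    precisions = [
        getattr(stat, "match%d" % n) / max(getattr(stat, "count%d" % n), 1)
        for n in range(1, order + 1)
    ]
    if min(precisions) == 0:
        return 0.0
    log_avg = sum(math.log(p) for p in precisions) / order
    # brevity penalty: exp(1 - reflen/predlen), capped at 1
    bp = min(1.0, math.exp(1 - stat.reflen / max(stat.predlen, 1)))
    return 100.0 * bp * math.exp(log_avg)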

View File

@@ -1,33 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <Python.h>
static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT
static struct PyModuleDef module_def = {
PyModuleDef_HEAD_INIT,
"libbleu", /* name of module */
// NOLINTNEXTLINE
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_def}; // NOLINT
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
PyObject* m = PyModule_Create(&module_def);
if (!m) {
return NULL;
}
return m;
}

View File

@@ -1,231 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>
using namespace ::std;
vector<vector<uint32_t>> edit_distance2_with_dp(
vector<uint32_t>& x,
vector<uint32_t>& y) {
uint32_t lx = x.size();
uint32_t ly = y.size();
vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
for (uint32_t i = 0; i < lx + 1; i++) {
d[i][0] = i;
}
for (uint32_t j = 0; j < ly + 1; j++) {
d[0][j] = j;
}
for (uint32_t i = 1; i < lx + 1; i++) {
for (uint32_t j = 1; j < ly + 1; j++) {
d[i][j] =
min(min(d[i - 1][j], d[i][j - 1]) + 1,
d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
}
}
return d;
}
vector<vector<uint32_t>> edit_distance2_backtracking(
vector<vector<uint32_t>>& d,
vector<uint32_t>& x,
vector<uint32_t>& y,
uint32_t terminal_symbol) {
vector<uint32_t> seq;
vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
/*
edit_seqs:
cells 0..x.size() hold the insertion sequences;
the last cell is the deletion sequence
*/
if (x.size() == 0) {
edit_seqs.at(0) = y;
return edit_seqs;
}
uint32_t i = d.size() - 1;
uint32_t j = d.at(0).size() - 1;
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
seq.push_back(1); // insert
seq.push_back(y.at(j - 1));
j--;
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
seq.push_back(2); // delete
seq.push_back(x.at(i - 1));
i--;
} else {
seq.push_back(3); // keep
seq.push_back(x.at(i - 1));
i--;
j--;
}
}
uint32_t prev_op, op, s, word;
prev_op = 0, s = 0;
for (uint32_t k = 0; k < seq.size() / 2; k++) {
op = seq.at(seq.size() - 2 * k - 2);
word = seq.at(seq.size() - 2 * k - 1);
if (prev_op != 1) {
s++;
}
if (op == 1) // insert
{
edit_seqs.at(s - 1).push_back(word);
} else if (op == 2) // delete
{
edit_seqs.at(x.size() + 1).push_back(1);
} else {
edit_seqs.at(x.size() + 1).push_back(0);
}
prev_op = op;
}
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
if (edit_seqs[k].size() == 0) {
edit_seqs[k].push_back(terminal_symbol);
}
}
return edit_seqs;
}
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
vector<vector<uint32_t>>& d,
vector<uint32_t>& x,
vector<uint32_t>& y,
uint32_t terminal_symbol,
uint32_t deletion_symbol) {
vector<uint32_t> seq;
vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
/*
edit_seqs:
cells 0..x.size() hold the insertion sequences;
the last cell is the deletion sequence
*/
if (x.size() == 0) {
edit_seqs.at(0) = y;
return edit_seqs;
}
uint32_t i = d.size() - 1;
uint32_t j = d.at(0).size() - 1;
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
seq.push_back(1); // insert
seq.push_back(y.at(j - 1));
j--;
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
seq.push_back(2); // delete
seq.push_back(x.at(i - 1));
i--;
} else {
seq.push_back(3); // keep
seq.push_back(x.at(i - 1));
i--;
j--;
}
}
uint32_t prev_op, op, s, word;
prev_op = 0, s = 0;
for (uint32_t k = 0; k < seq.size() / 2; k++) {
op = seq.at(seq.size() - 2 * k - 2);
word = seq.at(seq.size() - 2 * k - 1);
if (prev_op != 1) {
s++;
}
if (op == 1) // insert
{
edit_seqs.at(s - 1).push_back(word);
} else if (op == 2) // delete
{
edit_seqs.at(s - 1).push_back(deletion_symbol);
}
prev_op = op;
}
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
if (edit_seqs.at(k).size() == 0) {
edit_seqs.at(k).push_back(terminal_symbol);
}
}
return edit_seqs;
}
vector<uint32_t> compute_ed2(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys) {
vector<uint32_t> distances(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
}
return distances;
}
vector<vector<vector<uint32_t>>> suggested_ed2_path(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys,
uint32_t terminal_symbol) {
vector<vector<vector<uint32_t>>> seq(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
seq.at(i) =
edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
}
return seq;
}
vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
vector<vector<uint32_t>>& xs,
vector<vector<uint32_t>>& ys,
uint32_t terminal_symbol,
uint32_t deletion_symbol) {
vector<vector<vector<uint32_t>>> seq(xs.size());
for (uint32_t i = 0; i < xs.size(); i++) {
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
seq.at(i) = edit_distance2_backtracking_with_delete(
d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
}
return seq;
}
PYBIND11_MODULE(libnat, m) {
m.def("compute_ed2", &compute_ed2, "compute_ed2");
m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
m.def(
"suggested_ed2_path_with_delete",
&suggested_ed2_path_with_delete,
"suggested_ed2_path_with_delete");
}
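
A hedged sketch of calling these bindings from Python (the token sequences are toy data; the import assumes the extension above was built as `libnat`):

import libnat  # assumes the pybind11 module above is built and importable

xs = [[5, 6, 7], [8, 9]]     # hypothesis token ids (toy data)
ys = [[5, 7], [8, 10, 9]]    # reference token ids
print(libnat.compute_ed2(xs, ys))            # per-pair edit distances
print(libnat.suggested_ed2_path(xs, ys, 0))  # edit scripts; terminal symbol 0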

View File

@@ -1,67 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
This code is partially adopted from
https://github.com/1ytic/pytorch-edit-distance
*/
#include <torch/types.h>
#include "edit_dist.h"
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor LevenshteinDistance(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
CHECK_INPUT(source);
CHECK_INPUT(target);
CHECK_INPUT(source_length);
CHECK_INPUT(target_length);
return LevenshteinDistanceCuda(source, target, source_length, target_length);
}
torch::Tensor GenerateDeletionLabel(
torch::Tensor source,
torch::Tensor operations) {
CHECK_INPUT(source);
CHECK_INPUT(operations);
return GenerateDeletionLabelCuda(source, operations);
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
torch::Tensor target,
torch::Tensor operations) {
CHECK_INPUT(target);
CHECK_INPUT(operations);
return GenerateInsertionLabelCuda(target, operations);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
m.def(
"generate_deletion_labels",
&GenerateDeletionLabel,
"Generate Deletion Label");
m.def(
"generate_insertion_labels",
&GenerateInsertionLabel,
"Generate Insertion Label");
}
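
A hedged sketch of invoking these bindings from Python (the module name `libnat_cuda` and the shapes are assumptions, not part of this diff):

import torch
import libnat_cuda  # assumed name of the compiled CUDA extension

src = torch.randint(4, 100, (2, 10), dtype=torch.long, device="cuda")
tgt = torch.randint(4, 100, (2, 12), dtype=torch.long, device="cuda")
src_len = torch.tensor([10, 7], dtype=torch.int32, device="cuda")
tgt_len = torch.tensor([12, 9], dtype=torch.int32, device="cuda")
# per-sample sequences of edit operations (1=insert, 2=delete, 3=keep)
ops = libnat_cuda.levenshtein_distance(src, tgt, src_len, tgt_len)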

View File

@@ -1,344 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "edit_dist.h"
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
const scalar_t* __restrict__ source,
const size_t source_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * source_size;
for (int i = 0; i < source_size; i++) {
labels[offset_label + i] = 0;
}
int k = 0;
for (int i = 0; i < operation_size; i++) {
if (operations[offset + i] == 0) {
break;
} else if (operations[offset + i] == 1) {
continue;
} else {
labels[offset_label + k] = 3 - operations[offset + i];
k++;
}
}
}
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
const scalar_t* __restrict__ target,
const size_t target_size,
const size_t operation_size,
int* __restrict__ operations,
int* __restrict__ labels,
int* __restrict__ masks) {
const int index = blockIdx.x;
const int offset = index * operation_size;
const int offset_label = index * target_size;
int k = 0;
int u = 0;
int m = 0;
for (int i = 0; i < target_size; i++) {
labels[offset_label + i] = 0;
masks[offset_label + i] = 0;
}
for (int i = 0; i < operation_size - 1; i++) {
if (operations[offset + i] == 0) {
break;
} else if (operations[offset + i] == 2) {
continue;
} else if (operations[offset + i] == 1) {
masks[offset_label + m] = 1;
u++;
m++;
} else {
labels[offset_label + k] = u;
masks[offset_label + m] = 0;
k++;
m++;
u = 0;
}
}
}
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations,
int* __restrict__ errors_curr) {
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int d = index * (source_size + 1) * (target_size + 1);
const int t = target_size + 1;
auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++) {
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++) {
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++) {
for (int j = 1; j <= ref_len; j++) {
errors_curr[err_idx(i, j)] = min(
min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
1,
errors_curr[err_idx(i - 1, j - 1)] +
2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) &&
(errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 1;
j--; // insertion
} else if (
(i > 0) &&
(errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 2;
i--; // deletion
} else {
o--;
operations[opt_idx(o)] = 3;
i--;
j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len) {
operations[opt_idx(k)] = operations[opt_idx(k + o)];
} else {
operations[opt_idx(k)] = 0; // padding
}
}
}
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
const scalar_t* __restrict__ source,
const scalar_t* __restrict__ target,
const int* __restrict__ source_length,
const int* __restrict__ target_length,
const size_t source_size,
const size_t target_size,
int* __restrict__ operations) {
extern __shared__ short errors[];
auto errors_curr = errors;
const int index = blockIdx.x;
const int offset = index * (source_size + target_size);
const int t = target_size + 1;
auto err_idx = [t](int i, int j) { return i * t + j; };
auto opt_idx = [offset](int k) { return offset + k; };
const int hyp_len = source_length[index];
const int ref_len = target_length[index];
const scalar_t* hyp_begin = source + index * source_size;
const scalar_t* ref_begin = target + index * target_size;
// dynamic programming
for (int i = 0; i <= hyp_len; i++) {
errors_curr[err_idx(i, 0)] = i;
}
for (int j = 0; j <= ref_len; j++) {
errors_curr[err_idx(0, j)] = j;
}
for (int i = 1; i <= hyp_len; i++) {
for (int j = 1; j <= ref_len; j++) {
errors_curr[err_idx(i, j)] = min(
min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) +
1,
errors_curr[err_idx(i - 1, j - 1)] +
2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
}
}
// back-tracing
int i = hyp_len;
int j = ref_len;
int o = hyp_len + ref_len;
for (int k = 0; k < source_size + target_size; k++) {
operations[opt_idx(k)] = 0;
}
while ((i >= 0) && (j >= 0)) {
if ((i == 0) && (j == 0)) {
break;
}
if ((j > 0) &&
(errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 1;
j--; // insertion
} else if (
(i > 0) &&
(errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
o--;
operations[opt_idx(o)] = 2;
i--; // deletion
} else {
o--;
operations[opt_idx(o)] = 3;
i--;
j--; // do nothing
}
}
// moving to the left
for (int k = 0; k < hyp_len + ref_len; k++) {
if (k + o < hyp_len + ref_len) {
operations[opt_idx(k)] = operations[opt_idx(k + o)];
} else {
operations[opt_idx(k)] = 0; // padding
}
}
}
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations) {
const auto batch_size = source.size(0);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, source.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
generate_deletion_label_kernel<scalar_t>
<<<batch_size, 1, 0, stream>>>(
source.data_ptr<scalar_t>(),
source.size(1),
operations.size(1),
operations.data_ptr<int>(),
labels.data_ptr<int>());
}));
return labels;
}
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor target,
torch::Tensor operations) {
const auto batch_size = target.size(0);
at::TensorOptions options(target.device());
options = options.dtype(at::ScalarType::Int);
auto labels = torch::empty({batch_size, target.size(1)}, options);
auto masks = torch::empty({batch_size, target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
AT_DISPATCH_ALL_TYPES(
target.scalar_type(), "generate_insertion_labels", ([&] {
generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
target.data_ptr<scalar_t>(),
target.size(1),
operations.size(1),
operations.data_ptr<int>(),
labels.data_ptr<int>(),
masks.data_ptr<int>());
}));
return std::make_pair(labels, masks);
}
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length) {
const auto batch_size = source.size(0);
const auto shared_size =
(source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
at::TensorOptions options(source.device());
options = options.dtype(at::ScalarType::Int);
auto operations =
torch::empty({batch_size, source.size(1) + target.size(1)}, options);
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
if (shared_size > 40000) {
auto distances = torch::empty(
{batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
levenshtein_distance_kernel<scalar_t>
<<<batch_size, 1, 0, stream>>>(
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data_ptr<int>(),
distances.data_ptr<int>());
}));
} else {
AT_DISPATCH_ALL_TYPES(
source.scalar_type(), "faster_levenshtein_distance", ([&] {
faster_levenshtein_distance_kernel<scalar_t>
<<<batch_size, 1, shared_size, stream>>>(
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data_ptr<int>());
}));
}
return operations;
}

View File

@@ -1,25 +0,0 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor source,
torch::Tensor target,
torch::Tensor source_length,
torch::Tensor target_length);
torch::Tensor GenerateDeletionLabelCuda(
torch::Tensor source,
torch::Tensor operations);
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
torch::Tensor source,
torch::Tensor operations);

View File

@@ -1,4 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,19 +0,0 @@
# @package _group_
hydra:
run:
dir: .
defaults:
- _self_
- task: null
- model: null
- criterion: cross_entropy
- optimizer: null
- lr_scheduler: fixed
- bpe: null
- tokenizer: null
- scoring: null
- generation: null
- common_eval: null
- eval_lm: null

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.0
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 768
decoder_output_dim: 768
decoder_input_dim: 768
decoder_ffn_embed_dim: 3072
decoder_layers: 12
decoder_attention_heads: 12
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1600
decoder_output_dim: 1600
decoder_input_dim: 1600
decoder_ffn_embed_dim: 6400
decoder_layers: 48
decoder_attention_heads: 25
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1280
decoder_output_dim: 1280
decoder_input_dim: 1280
decoder_ffn_embed_dim: 5120
decoder_layers: 36
decoder_attention_heads: 20
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 24
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,36 +0,0 @@
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0

View File

@@ -1,5 +0,0 @@
# @package _group_
activation: gelu
vq_type: gumbel
vq_depth: 2
combine_groups: true

View File

@@ -1,8 +0,0 @@
# @package _group_
quantize_targets: true
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
feature_grad_mult: 0.1

View File

@@ -1,20 +0,0 @@
# @package _group_
quantize_targets: true
extractor_mode: layer_norm
layer_norm_first: true
final_dim: 768
latent_temp: [2.0,0.1,0.999995]
encoder_layerdrop: 0.0
dropout_input: 0.0
dropout_features: 0.0
dropout: 0.0
attention_dropout: 0.0
conv_bias: true
encoder_layers: 24
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16
feature_grad_mult: 1.0

View File

@@ -1,36 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from fairseq import registry
from fairseq.criterions.fairseq_criterion import ( # noqa
FairseqCriterion,
LegacyFairseqCriterion,
)
from omegaconf import DictConfig
(
build_criterion_,
register_criterion,
CRITERION_REGISTRY,
CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
"--criterion", base_class=FairseqCriterion, default="cross_entropy"
)
def build_criterion(cfg: DictConfig, task):
return build_criterion_(cfg, task)
# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("fairseq.criterions." + file_name)
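
This registry is what the @register_criterion decorators in the files below hook into: registering a name plus a dataclass config makes the criterion buildable from configuration. A hedged minimal sketch of a criterion plugging in (the toy_loss name and empty config are made up):

from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass

@dataclass
class ToyLossConfig(FairseqDataclass):
    pass

@register_criterion("toy_loss", dataclass=ToyLossConfig)
class ToyLoss(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # a real criterion returns (loss, sample_size, logging_output)
        raise NotImplementedError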

View File

@@ -1,123 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
class AdaptiveLoss(FairseqCriterion):
"""This is an implementation of the loss function accompanying the adaptive softmax approximation for
    graphics processing units (GPUs), described in the paper "Efficient softmax approximation for GPUs"
(http://arxiv.org/abs/1609.04309)."""
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
@classmethod
def build_criterion(cls, cfg: AdaptiveLossConfig, task):
if cfg.ddp_backend in {"c10d", "pytorch_ddp"}:
raise Exception(
"AdaptiveLoss is not compatible with the PyTorch "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=legacy_ddp` instead."
)
return cls(task, cfg.sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model.decoder, "adaptive_softmax")
and model.decoder.adaptive_softmax is not None
)
adaptive_softmax = model.decoder.adaptive_softmax
net_output = model(**sample["net_input"])
orig_target = model.get_targets(sample, net_output)
nsentences = orig_target.size(0)
orig_target = orig_target.view(-1)
bsz = orig_target.size(0)
logits, target = adaptive_softmax(net_output[0], orig_target)
assert len(target) == len(logits)
loss = net_output[0].new(1 if reduce else bsz).zero_()
for i in range(len(target)):
if target[i] is not None:
assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1)
loss += F.cross_entropy(
logits[i],
target[i],
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
orig = utils.strip_pad(orig_target, self.padding_idx)
ntokens = orig.numel()
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
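
Aside, not part of the deleted file: reduce_metrics above divides the summed loss (in nats) by log(2) to report bits per token, and perplexity is then 2 to that average (which is what utils.get_perplexity derives). A toy check with hypothetical numbers:

import math

loss_sum, sample_size = 1385.0, 1000.0   # hypothetical aggregated logs
bits_per_token = loss_sum / sample_size / math.log(2)
ppl = 2 ** bits_per_token
print(f"{bits_per_token:.3f} bits/token -> ppl {ppl:.2f}")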

View File

@@ -1,100 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from torch import nn
@register_criterion("composite_loss")
class CompositeLoss(LegacyFairseqCriterion):
"""This is a composite loss that, given a list of model outputs and a list of targets,
computes an average of losses for each output-target pair"""
def __init__(self, args, task):
super().__init__(args, task)
self.underlying_criterion = args.underlying_criterion
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
# fmt: off
parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
help='underlying criterion to use for the composite loss')
# fmt: on
@staticmethod
def build_underlying_criterion(args, task):
saved_criterion = args.criterion
args.criterion = args.underlying_criterion
assert saved_criterion != args.underlying_criterion
underlying_criterion = task.build_criterion(args)
args.criterion = saved_criterion
return underlying_criterion
@classmethod
def build_criterion(cls, args, task):
underlying_criterion = CompositeLoss.build_underlying_criterion(args, task)
class FakeModel(nn.Module):
def __init__(self, model, net_out, target):
super().__init__()
self.model = model
self.net_out = net_out
self.target = target
def forward(self, **unused):
return self.net_out
def get_normalized_probs(self, net_output, log_probs, sample=None):
return self.model.get_normalized_probs(
net_output, log_probs, sample=sample
)
def get_targets(self, *unused):
return self.target
@property
def decoder(self):
return self.model.decoder
class _CompositeLoss(LegacyFairseqCriterion):
def __init__(self, args, task, underlying_criterion):
super().__init__(args, task)
self.underlying_criterion = underlying_criterion
def forward(self, model, sample, reduce=True):
net_outputs = model(**sample["net_input"])
targets = sample["target"]
bsz = targets[0].size(0)
loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()
sample_size = 0
logging_output = {}
for o, t in zip(net_outputs[0], targets):
m = FakeModel(model, (o, net_outputs[1]), t)
sample["target"] = t
l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
loss += l
sample_size += ss
loss.div_(len(targets))
sample_size /= len(targets)
logging_output["loss"] = utils.item(loss.data) if reduce else loss.data
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
return underlying_criterion.__class__.aggregate_logging_outputs(
logging_outputs
)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)
return _CompositeLoss(args, task, underlying_criterion)

View File

@@ -1,90 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(
lprobs,
target,
ignore_index=self.padding_idx,
reduction="sum" if reduce else "none",
)
return loss, loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
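
Aside, not part of the deleted file: compute_loss above is just cross-entropy split into log-softmax plus NLL. A toy equivalence check:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)                     # (tokens, vocab)
target = torch.randint(0, 5, (8,))
nll = F.nll_loss(F.log_softmax(logits, dim=-1), target, reduction="sum")
ce = F.cross_entropy(logits, target, reduction="sum")
assert torch.allclose(nll, ce)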

View File

@@ -1,319 +0,0 @@
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn.functional as F
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.tasks import FairseqTask
@dataclass
class CtcCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model: Optional[str] = field(
default=None,
metadata={
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon: Optional[str] = field(
default=None,
metadata={"help": "lexicon to use with wer_kenlm_model"},
)
wer_lm_weight: float = field(
default=2.0,
metadata={"help": "lm weight to use with wer_kenlm_model"},
)
wer_word_score: float = field(
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
wer_args: Optional[str] = field(
default=None,
metadata={
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
@register_criterion("ctc", dataclass=CtcCriterionConfig)
class CtcCriterion(FairseqCriterion):
def __init__(
self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0
):
super().__init__(task)
self.blank_idx = (
task.target_dictionary.index(task.blank_symbol)
if hasattr(task, "blank_symbol")
else 0
)
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
self.rdrop_alpha = rdrop_alpha
if cfg.wer_args is not None:
(
cfg.wer_kenlm_model,
cfg.wer_lexicon,
cfg.wer_lm_weight,
cfg.wer_word_score,
) = eval(cfg.wer_args)
if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "":
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = cfg.wer_kenlm_model
dec_args.lexicon = cfg.wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = cfg.wer_lm_weight
dec_args.word_score = cfg.wer_word_score
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity
self.sentence_avg = cfg.sentence_avg
def forward(self, model, sample, reduce=True, **kwargs):
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
# CTC loss is calculated over duplicated inputs
# sample is already duplicated for R-Drop
if self.rdrop_alpha > 0:
for k, v in sample.items():
if k in ["target", "target_lengths"]:
sample[k] = torch.cat([v, v.clone()], dim=0)
elif k == "net_input":
if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size(
0
):
# for decoder CTC loss
sample[k]["src_lengths"] = torch.cat(
[
sample[k]["src_lengths"],
sample[k]["src_lengths"].clone(),
],
dim=0,
)
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
if net_output["padding_mask"] is not None:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
else:
input_lengths = lprobs.new_full(
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
if "target_lengths" in sample:
target_lengths = sample["target_lengths"]
else:
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
if not model.training:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
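
Aside, not part of the deleted file: the shapes fed to F.ctc_loss above, in miniature: (T, B, C) log-probabilities, targets flattened into one 1-D tensor, and per-sample input/target lengths.

import torch
import torch.nn.functional as F

T, B, C = 50, 2, 10                      # time, batch, classes (0 = blank)
lprobs = torch.randn(T, B, C).log_softmax(-1)
target_lengths = torch.tensor([12, 9])
targets_flat = torch.randint(1, C, (int(target_lengths.sum()),))
input_lengths = torch.full((B,), T, dtype=torch.long)
loss = F.ctc_loss(
    lprobs, targets_flat, input_lengths, target_lengths,
    blank=0, reduction="sum", zero_infinity=True,
)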

View File

@@ -1,120 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Dict, List
from fairseq import metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, task):
super().__init__()
self.task = task
if hasattr(task, "target_dictionary"):
tgt_dict = task.target_dictionary
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
@classmethod
def add_args(cls, parser):
"""Add criterion-specific arguments to the parser."""
dc = getattr(cls, "__dataclass", None)
if dc is not None:
gen_parser_from_dataclass(parser, dc())
@classmethod
def build_criterion(cls, cfg: FairseqDataclass, task):
"""Construct a criterion from command-line args."""
# arguments in the __init__.
init_args = {}
for p in inspect.signature(cls).parameters.values():
if (
p.kind == p.POSITIONAL_ONLY
or p.kind == p.VAR_POSITIONAL
or p.kind == p.VAR_KEYWORD
):
# we haven't implemented inference for these argument types,
# but PRs welcome :)
raise NotImplementedError("{} not supported".format(p.kind))
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
if p.name == "task":
init_args["task"] = task
elif p.name == "cfg":
init_args["cfg"] = cfg
elif hasattr(cfg, p.name):
init_args[p.name] = getattr(cfg, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
raise NotImplementedError(
"Unable to infer Criterion arguments, please implement "
"{}.build_criterion".format(cls.__name__)
)
return cls(**init_args)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError
@staticmethod
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"The aggregate_logging_outputs API is deprecated. "
"Please use the reduce_metrics API instead."
)
raise NotImplementedError
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
"Criterions should implement the reduce_metrics API. "
"Falling back to deprecated aggregate_logging_outputs API."
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
if k in {"nsentences", "ntokens", "sample_size"}:
continue
metrics.log_scalar(k, v)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return False
class LegacyFairseqCriterion(FairseqCriterion):
def __init__(self, args, task):
super().__init__(task=task)
self.args = args
utils.deprecation_warning(
"Criterions should take explicit arguments instead of an "
"argparse.Namespace object, please update your criterion by "
"extending FairseqCriterion instead of LegacyFairseqCriterion."
)
@classmethod
def build_criterion(cls, args, task):
"""Construct a criterion from command-line args."""
return cls(args, task)
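
Aside, not part of the deleted file: the heart of FairseqCriterion.build_criterion above is mapping config fields onto __init__ parameters by name via inspect.signature. A standalone sketch of that pattern (all names below are hypothetical):

import inspect
from dataclasses import dataclass


@dataclass
class ToyConfig:
    sentence_avg: bool = True


class ToyCriterion:
    def __init__(self, task, sentence_avg):
        self.task, self.sentence_avg = task, sentence_avg


def build(cls, cfg, task):
    kwargs = {}
    for p in inspect.signature(cls).parameters.values():
        if p.name == "task":
            kwargs["task"] = task
        elif hasattr(cfg, p.name):
            kwargs[p.name] = getattr(cfg, p.name)
    return cls(**kwargs)


crit = build(ToyCriterion, ToyConfig(), task="dummy")
assert crit.sentence_avg is True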

View File

@@ -1,136 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from typing import List, Dict, Any
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel
@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
def __init__(self, task, ctc_weight):
super().__init__(task)
self.ctc_weight = ctc_weight
def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
_feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
durations=sample["durations"],
pitches=sample["pitches"],
energies=sample["energies"],
)
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
tgt_mask = lengths_to_mask(sample["target_lengths"])
pitches, energies = sample["pitches"], sample["energies"]
pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
energy_out, energies = energy_out[src_mask], energies[src_mask]
feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
if _feat_out_post is not None:
l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
log_dur_out = log_dur_out[src_mask]
dur = sample["durations"].float()
dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
sample_size = sample["nsentences"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"dur_loss": utils.item(dur_loss.data),
"pitch_loss": utils.item(pitch_loss.data),
"energy_loss": utils.item(energy_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in [
"loss",
"l1_loss",
"dur_loss",
"pitch_loss",
"energy_loss",
"ctc_loss",
]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
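
Aside, not part of the deleted file: lengths_to_mask, imported above, builds the usual (batch, max_len) validity mask. A minimal equivalent, assuming the convention that True marks positions inside each sequence:

import torch

lengths = torch.tensor([3, 1, 4])
mask = torch.arange(int(lengths.max())) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False],
#         [ True,  True,  True,  True]])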

View File

@@ -1,194 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import re
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class HubertCriterionConfig(FairseqDataclass):
pred_masked_weight: float = field(
default=1.0,
metadata={"help": "weight for predictive loss for masked frames"},
)
pred_nomask_weight: float = field(
default=0.0,
metadata={"help": "weight for predictive loss for unmasked frames"},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss = 0.0
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float(net_output[lk])
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return False
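
Aside, not part of the deleted file: compute_correct above assumes the positive candidate sits at index 0 of each logit row (HuBERT's get_targets returns zeros), and rows where argmax and argmin are both 0, i.e. constant rows, are excluded. A toy illustration of that counting rule:

import torch

logits = torch.tensor([[2.0, 0.5, 0.1],   # argmax == 0: correct
                       [0.3, 1.2, 0.8],   # argmax != 0: wrong
                       [0.0, 0.0, 0.0]])  # argmax == argmin == 0: excluded
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
corr = is_max.long().sum().item() - (is_max & is_min).long().sum().item()
assert corr == 1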

View File

@@ -1,168 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "report accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
sentence_avg: bool = II("optimization.sentence_avg")
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = epsilon / (lprobs.size(-1) - 1)
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
@register_criterion(
"label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig
)
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.ignore_prefix_size > 0:
# lprobs: B x T x C
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("n_correct", n_correct)
metrics.log_derived(
"accuracy",
lambda meters: round(
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
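
Aside, not part of the deleted file: a toy exercise of label_smoothed_nll_loss, assuming it is importable from fairseq.criterions.label_smoothed_cross_entropy (as the rdrop criterion below does); with epsilon = 0 it collapses to plain NLL:

import torch
import torch.nn.functional as F

from fairseq.criterions.label_smoothed_cross_entropy import (
    label_smoothed_nll_loss,
)

lprobs = F.log_softmax(torch.randn(6, 5), dim=-1)
target = torch.randint(0, 5, (6,))
loss, nll = label_smoothed_nll_loss(lprobs, target, epsilon=0.0)
assert torch.allclose(loss, nll)
assert torch.allclose(nll, F.nll_loss(lprobs, target, reduction="sum"))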

View File

@@ -1,220 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
try:
from simuleval.metrics.latency import (
AverageLagging,
AverageProportion,
DifferentiableAverageLagging,
)
LATENCY_METRICS = {
"average_lagging": AverageLagging,
"average_proportion": AverageProportion,
"differentiable_average_lagging": DifferentiableAverageLagging,
}
except ImportError:
LATENCY_METRICS = None
@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
latency_avg_weight: float = field(
default=0.0,
metadata={"help": "weight fot average latency loss."},
)
latency_var_weight: float = field(
default=0.0,
metadata={"help": "weight fot variance latency loss."},
)
latency_avg_type: str = field(
default="differentiable_average_lagging",
metadata={"help": "latency type for average loss"},
)
latency_var_type: str = field(
default="variance_delay",
metadata={"help": "latency typ for variance loss"},
)
latency_gather_method: str = field(
default="weighted_average",
metadata={"help": "method to gather latency loss for all heads"},
)
latency_update_after: int = field(
default=0,
metadata={"help": "Add latency loss after certain steps"},
)
@register_criterion(
"latency_augmented_label_smoothed_cross_entropy",
dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
latency_avg_weight,
latency_var_weight,
latency_avg_type,
latency_var_type,
latency_gather_method,
latency_update_after,
):
super().__init__(
task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
)
assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."
self.latency_avg_weight = latency_avg_weight
self.latency_var_weight = latency_var_weight
self.latency_avg_type = latency_avg_type
self.latency_var_type = latency_var_type
self.latency_gather_method = latency_gather_method
self.latency_update_after = latency_update_after
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
# 1. Compute cross entropy loss
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
# 2. Compute latency loss
latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
model, sample, net_output
)
if self.latency_update_after > 0:
num_updates = getattr(model.decoder, "num_updates", None)
assert (
num_updates is not None
), "model.decoder doesn't have attribute 'num_updates'"
if num_updates <= self.latency_update_after:
latency_loss = 0
loss += latency_loss
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"latency": expected_latency,
"delays_var": expected_delays_var,
"latency_loss": latency_loss,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
def compute_latency_loss(self, model, sample, net_output):
assert (
net_output[-1].encoder_padding_mask is None
or not net_output[-1].encoder_padding_mask[:, 0].any()
), "Only right padding on source is supported."
# 1. Obtain the expected alignment
alpha_list = [item["alpha"] for item in net_output[1].attn_list]
num_layers = len(alpha_list)
bsz, num_heads, tgt_len, src_len = alpha_list[0].size()
# bsz * num_layers * num_heads, tgt_len, src_len
alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)
# 2. Compute expected delays
# bsz * num_heads * num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(alpha_all)
.type_as(alpha_all)
)
expected_delays = torch.sum(steps * alpha_all, dim=-1)
target_padding_mask = (
model.get_targets(sample, net_output)
.eq(self.padding_idx)
.unsqueeze(1)
.expand(bsz, num_layers * num_heads, tgt_len)
.contiguous()
.view(-1, tgt_len)
)
src_lengths = (
sample["net_input"]["src_lengths"]
.unsqueeze(1)
.expand(bsz, num_layers * num_heads)
.contiguous()
.view(-1)
)
expected_latency = LATENCY_METRICS[self.latency_avg_type](
expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
)
# 2.1 average expected latency of heads
# bsz, num_layers * num_heads
expected_latency = expected_latency.view(bsz, -1)
if self.latency_gather_method == "average":
# bsz * tgt_len
expected_latency = expected_delays.mean(dim=1)
elif self.latency_gather_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_latency, dim=1)
expected_latency = torch.sum(expected_latency * weights, dim=1)
elif self.latency_gather_method == "max":
expected_latency = expected_latency.max(dim=1)[0]
else:
raise NotImplementedError
expected_latency = expected_latency.sum()
avg_loss = self.latency_avg_weight * expected_latency
# 2.2 variance of expected delays
expected_delays_var = (
expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
)
expected_delays_var = expected_delays_var.sum()
var_loss = self.latency_var_weight * expected_delays_var
# 3. Final loss
latency_loss = avg_loss + var_loss
return latency_loss, expected_latency, expected_delays_var
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
latency = sum(log.get("latency", 0) for log in logging_outputs)
delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
metrics.log_scalar(
"latency_loss", latency_loss / nsentences, nsentences, round=3
)
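
Aside, not part of the deleted file: the expected-delay step above in one line. With attention mass alpha over source positions j = 1..S, a target step's expected delay is sum_j j * alpha_j:

import torch

alpha = torch.tensor([[0.1, 0.2, 0.7]])          # one target step, S = 3
steps = torch.arange(1, alpha.size(-1) + 1).float()
expected_delay = (steps * alpha).sum(-1)         # tensor([2.6000])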

View File

@@ -1,130 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from .label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
from dataclasses import dataclass, field
@dataclass
class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
alignment_lambda: float = field(
default=0.05, metadata={"help": "weight for the alignment loss"}
)
@register_criterion(
"label_smoothed_cross_entropy_with_alignment",
dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig,
)
class LabelSmoothedCrossEntropyCriterionWithAlignment(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
super().__init__(task, sentence_avg, label_smoothing)
self.alignment_lambda = alignment_lambda
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
alignment_loss = None
# Compute alignment loss only for training set and non dummy batches.
if "alignments" in sample and sample["alignments"] is not None:
alignment_loss = self.compute_alignment_loss(sample, net_output)
if alignment_loss is not None:
logging_output["alignment_loss"] = utils.item(alignment_loss.data)
loss += self.alignment_lambda * alignment_loss
return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
attn_prob = net_output[1]["attn"][0]
bsz, tgt_sz, src_sz = attn_prob.shape
attn = attn_prob.view(bsz * tgt_sz, src_sz)
align = sample["alignments"]
align_weights = sample["align_weights"].float()
if len(align) > 0:
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
loss = -(
(attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
* align_weights[:, None]
).sum()
else:
return None
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss_sum = utils.item(
sum(log.get("nll_loss", 0) for log in logging_outputs)
)
alignment_loss_sum = utils.item(
sum(log.get("alignment_loss", 0) for log in logging_outputs)
)
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_scalar(
"alignment_loss",
alignment_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
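
Aside, not part of the deleted file: the indexing in compute_alignment_loss, in miniature. Each align row is a (src, tgt) pair, so align[:, 1] selects the flattened target row and align[:, 0] the source column; the weights below stand in for the 1/frequency normalization described in the comment above:

import torch

attn = torch.full((4, 3), 0.25)          # (bsz * tgt_len, src_len)
align = torch.tensor([[0, 1], [2, 3]])   # (src, tgt) index pairs
align_weights = torch.ones(2)            # 1 / tgt-index frequency in practice
picked = attn[align[:, 1][:, None], align[:, 0][:, None]]
loss = -(picked.log() * align_weights[:, None]).sum()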

View File

@@ -1,96 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.data.data_utils import lengths_to_mask
@dataclass
class LabelSmoothedCrossEntropyWithCtcCriterionConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"})
@register_criterion(
"label_smoothed_cross_entropy_with_ctc",
dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig,
)
class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
ctc_weight,
):
super().__init__(
task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
)
self.ctc_weight = ctc_weight
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
ctc_loss = torch.tensor(0.0).type_as(loss)
if self.ctc_weight > 0.0:
ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample)
ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample)
ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens)
ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask)
reduction = "sum" if reduce else "none"
ctc_loss = (
F.ctc_loss(
ctc_lprobs,
ctc_tgt_flat,
ctc_lens,
ctc_tgt_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss += ctc_loss
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data),
"nll_loss": utils.item(nll_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)

View File

@@ -1,176 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
label_smoothed_nll_loss,
)
@dataclass
class RdropLabelSmoothedCrossEntropyCriterionConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
rdrop_alpha: float = field(
default=0.0,
metadata={"help": "alpha for r-drop, 0 means no r-drop"},
)
@register_criterion(
"label_smoothed_cross_entropy_with_rdrop",
dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig,
)
class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=ignore_prefix_size,
report_accuracy=report_accuracy,
)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
self.rdrop_alpha = rdrop_alpha
def forward(self, model, sample, reduce=True, net_output=None):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
if net_output is None:
if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size(
0
) == sample["target"].size(0):
sample = duplicate_input(sample)
net_output = model(**sample["net_input"])
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, net_output, sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0):
target = torch.cat([target, target.clone()], dim=0)
if self.ignore_prefix_size > 0:
# lprobs: B x T x C
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
if self.rdrop_alpha > 0:
pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx)
rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask)
loss += self.rdrop_alpha * rdrop_kl_loss
else:
rdrop_kl_loss = loss.new_zeros(1)
return loss, nll_loss, rdrop_kl_loss
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
super().reduce_metrics(logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
rdrop_kl_loss = utils.item(
sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs)
/ sample_size
/ math.log(2)
)
if rdrop_kl_loss > 0:
metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss)
def duplicate_input(sample):
if "net_input" in sample.keys():
sample_input = sample["net_input"]
else:
sample_input = sample
for k, v in sample_input.items():
if isinstance(v, torch.Tensor):
sample_input[k] = torch.cat([v, v.clone()], dim=0)
if "net_input" in sample.keys():
sample["net_input"] = sample_input
else:
sample = sample_input
return sample
def compute_kl_loss(model, net_output, pad_mask=None, reduce=True):
net_prob = model.get_normalized_probs(net_output, log_probs=True)
net_prob_tec = model.get_normalized_probs(net_output, log_probs=False)
net_prob = net_prob.view(-1, net_prob.size(-1))
net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1))
p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0)
p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0)
p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none")
q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none")
if pad_mask is not None:
p_loss.masked_fill_(pad_mask, 0.0)
q_loss.masked_fill_(pad_mask, 0.0)
if reduce:
p_loss = p_loss.sum()
q_loss = q_loss.sum()
loss = (p_loss + q_loss) / 2
return loss
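
Aside, not part of the deleted file: compute_kl_loss above on toy data. The batch holds two stochastic forward passes of the same inputs, split in half and penalized with a symmetric KL term (masking omitted here, so the sum reduction is applied directly):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)                 # two dropout passes of 2 samples
logp = F.log_softmax(logits, dim=-1)
prob = F.softmax(logits, dim=-1)
p, q = torch.split(logp, 2, dim=0)
p_tec, q_tec = torch.split(prob, 2, dim=0)
kl = (F.kl_div(p, q_tec, reduction="sum")
      + F.kl_div(q, p_tec, reduction="sum")) / 2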

View File

@@ -1,177 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
"""
Function to compute the cross entropy loss. The default value of
ignore_index is the same as the default value for F.cross_entropy in
pytorch.
"""
assert logits.size(0) == targets.size(
-1
), "Logits and Targets tensor shapes don't match up"
loss = F.nll_loss(
F.log_softmax(logits, -1, dtype=torch.float32),
targets,
reduction="sum",
ignore_index=ignore_index,
)
return loss
@register_criterion("legacy_masked_lm_loss")
class LegacyMaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
This optionally also computes the next sentence prediction (NSP) loss and
adds it to the overall loss based on the specified args. There are three
cases to consider:
1) Generic MLM training without NSP loss. In this case sentence_targets
and sentence_logits are both None.
2) BERT training without NSP loss. In this case sentence_targets is
not None but sentence_logits is None and we should not be computing
a sentence level loss.
3) BERT training with NSP loss. In this case both sentence_targets and
sentence_logits are not None and we should be computing a sentence
level loss. The weight of the sentence level loss is specified as
an argument.
"""
def __init__(self, task, masked_lm_only, nsp_loss_weight):
super().__init__(task)
self.masked_lm_only = masked_lm_only
self.nsp_loss_weight = nsp_loss_weight
@staticmethod
def add_args(parser):
"""Args for MaskedLM Loss"""
# Default for masked_lm_only is False so as to not break BERT training
parser.add_argument(
"--masked-lm-only",
default=False,
action="store_true",
help="compute MLM loss only",
)
parser.add_argument(
"--nsp-loss-weight",
default=1.0,
type=float,
help="weight for next sentence prediction" " loss (default 1)",
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
lm_logits, output_metadata = model(**sample["net_input"])
# reshape lm_logits from (N,T,C) to (N*T,C)
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
lm_targets = sample["lm_target"].view(-1)
lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx)
# compute the number of tokens for which loss is computed. This is used
# to normalize the loss
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
loss = lm_loss / ntokens
nsentences = sample["nsentences"]
# nsentences = 0
# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.masked_lm_only:
sentence_logits = output_metadata["sentence_logits"]
sentence_targets = sample["sentence_target"].view(-1)
# This needs to be recomputed due to some differences between
# TokenBlock and BlockPair dataset. This can be resolved with a
# refactor of BERTModel which we will do in the future.
# TODO: Remove this after refactor of BERTModel
nsentences = sentence_targets.size(0)
# Check for logits being none which can happen when remove_heads
# is set to true in the BERT model. Ideally we should set
# masked_lm_only to true in this case, but that requires some
# refactor in the BERT model.
if sentence_logits is not None:
sentence_loss = compute_cross_entropy_loss(
sentence_logits, sentence_targets
)
loss += self.nsp_loss_weight * (sentence_loss / nsentences)
# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data,
# sentence loss is not always computed
"sentence_loss": (
(utils.item(sentence_loss.data) if reduce else sentence_loss.data)
if sentence_loss is not None
else 0.0
),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs)
sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_loss = sum(log.get("loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss",
agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
metrics.log_scalar(
"lm_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
metrics.log_scalar(
"sentence_loss",
sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0,
nsentences,
round=3,
)
metrics.log_scalar(
"nll_loss",
lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0,
ntokens,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
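
A minimal standalone sketch of the loss combination this deleted criterion implements (per-token MLM loss plus a weighted per-sentence NSP loss); all tensors and the weight below are toy values for illustration:

import torch
import torch.nn.functional as F

# 8 masked-token predictions over a 100-word vocab, and 2 sentence-pair
# predictions over 2 classes (is-next / not-next); toy data.
lm_logits = torch.randn(8, 100)
lm_targets = torch.randint(0, 100, (8,))
sentence_logits = torch.randn(2, 2)
sentence_targets = torch.randint(0, 2, (2,))
nsp_loss_weight = 1.0  # the criterion's default

lm_loss = F.nll_loss(F.log_softmax(lm_logits, -1), lm_targets, reduction="sum")
nsp_loss = F.nll_loss(F.log_softmax(sentence_logits, -1), sentence_targets, reduction="sum")

# As in forward(): normalize each term by its own count, then add.
loss = lm_loss / lm_targets.numel() + nsp_loss_weight * (nsp_loss / sentence_targets.numel())
print(float(loss))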

View File

@@ -1,98 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import math
from omegaconf import II
import torch
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class MaskedLmConfig(FairseqDataclass):
tpu: bool = II("common.tpu")
@register_criterion("masked_lm", dataclass=MaskedLmConfig)
class MaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
"""
def __init__(self, cfg: MaskedLmConfig, task):
super().__init__(task)
self.tpu = cfg.tpu
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
masked_tokens = sample["target"].ne(self.padding_idx)
sample_size = masked_tokens.int().sum()
# Rare: when all tokens are masked, project all tokens.
# We use torch.where to avoid device-to-host transfers,
# except on CPU where torch.where is not well supported
# (see github.com/pytorch/pytorch/issues/26247).
if self.tpu:
masked_tokens = None # always project all tokens on TPU
elif masked_tokens.device == torch.device("cpu"):
if not masked_tokens.any():
masked_tokens = None
else:
masked_tokens = torch.where(
masked_tokens.any(),
masked_tokens,
masked_tokens.new([True]),
)
logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
if masked_tokens is not None:
targets = targets[masked_tokens]
loss = modules.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
reduction="sum",
ignore_index=self.padding_idx,
)
logging_output = {
"loss": loss if self.tpu else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
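
The torch.where fallback above is the subtle part: on an all-False mask it must not produce an empty index. A minimal sketch with a toy mask, assuming CPU tensors:

import torch

masked_tokens = torch.zeros(4, 6, dtype=torch.bool)  # rare case: nothing masked

# Keeping this as a tensor op avoids a device-to-host sync on CUDA; when
# no token is masked, the scalar condition is False and the [True] value
# broadcasts to an all-True mask, i.e. all tokens get projected.
masked_tokens = torch.where(
    masked_tokens.any(),
    masked_tokens,
    masked_tokens.new([True]),
)
print(masked_tokens.sum())  # 24 here, so targets[masked_tokens] is never empty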

View File

@@ -1,155 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Dict, List
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class ModelCriterionConfig(FairseqDataclass):
loss_weights: Dict[str, float] = field(
default_factory=dict,
metadata={"help": "weights for the loss terms"},
)
log_keys: List[str] = field(
default_factory=list,
metadata={"help": "additional output keys to log"},
)
@register_criterion("model", dataclass=ModelCriterionConfig)
class ModelCriterion(FairseqCriterion):
"""
This criterion relies on the model to supply losses.
The losses should be a dictionary of name -> scalar returned by
the model either by including it in the net_output dict or by
implementing a get_losses(net_output, sample) method. The final loss is
a scaled sum of all losses according to weights in loss_weights.
If no weights are provided, then all losses are scaled by 1.0.
The losses will be automatically logged. Additional keys from
net_output dict can be logged via the log_keys parameter.
"""
def __init__(self, task, loss_weights=None, log_keys=None):
super().__init__(task)
self.loss_weights = loss_weights
self.log_keys = log_keys
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
scaled_losses = {}
if hasattr(model, "get_losses"):
losses = model.get_losses(net_output, sample)
elif isinstance(net_output, dict) and "losses" in net_output:
losses = net_output["losses"]
else:
raise Exception("Could not retrieve losses")
for lk, p in losses.items():
try:
coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk]
except KeyError:
logger.error(
f"weight for loss {lk} is not in loss_weights ({self.loss_weights})"
)
raise
if coef != 0 and p is not None:
scaled_losses[lk] = coef * p.float()
loss = sum(scaled_losses.values())
if "sample_size" in net_output:
sample_size = net_output["sample_size"]
else:
sample_size = loss.numel()
if reduce and loss.numel() > 1:
loss = loss.sum()
logging_output = {
"loss": loss.data,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
"_world_size": 1,
}
for lk in self.log_keys:
if lk in net_output and net_output[lk] is not None:
if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
logging_output[lk] = float(net_output[lk])
else:
for i, v in enumerate(net_output[lk]):
logging_output[f"{lk}_{i}"] = float(v)
if len(scaled_losses) > 1:
for lk, l in scaled_losses.items():
if l.numel() > 1:
l = l.sum()
logging_output[f"loss_{lk}"] = l.item()
if "logs" in net_output:
for lgw in net_output["logs"]:
logging_output[lgw] = net_output["logs"][lgw]
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"_world_size",
}
world_size = utils.item(
sum(log.get("_world_size", 0) for log in logging_outputs)
)
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs)
if k.startswith("loss_"):
metrics.log_scalar(k, val / sample_size, sample_size, round=3)
else:
metrics.log_scalar(k, val / world_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
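
The contract this criterion expects from a model, as a minimal sketch; ToyModel and its loss names are hypothetical:

import torch

class ToyModel(torch.nn.Module):
    # Hypothetical model: its output dict carries a name -> scalar "losses" entry.
    def forward(self, x):
        return {
            "losses": {"recon": (x ** 2).mean(), "reg": x.abs().mean()},
            "sample_size": x.numel(),
        }

net_output = ToyModel()(torch.randn(3, 5))
loss_weights = {"recon": 1.0, "reg": 0.1}

# Weighted sum over named losses, mirroring ModelCriterion.forward().
loss = sum(loss_weights[k] * v.float() for k, v in net_output["losses"].items())
print(float(loss))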

View File

@@ -1,180 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from torch import Tensor
from dataclasses import dataclass, field
@dataclass
class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass):
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig)
class LabelSmoothedDualImitationCriterion(FairseqCriterion):
def __init__(self, task, label_smoothing):
super().__init__(task)
self.label_smoothing = label_smoothing
def _compute_loss(
self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0
):
"""
outputs: batch x len x d_model
targets: batch x len
masks: batch x len
policy_logprob: if there is some policy
depends on the likelihood score as rewards.
"""
def mean_ds(x: Tensor, dim=None) -> Tensor:
return (
x.float().mean().type_as(x)
if dim is None
else x.float().mean(dim).type_as(x)
)
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
logits = F.log_softmax(outputs, dim=-1)
if targets.dim() == 1:
losses = F.nll_loss(logits, targets.to(logits.device), reduction="none")
else: # soft-labels
losses = F.kl_div(logits, targets.to(logits.device), reduction="none")
losses = losses.sum(-1)
nll_loss = mean_ds(losses)
if label_smoothing > 0:
loss = (
nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing
)
else:
loss = nll_loss
loss = loss * factor
return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor}
def _custom_loss(self, loss, name="loss", factor=1.0):
return {"name": name, "loss": loss, "factor": factor}
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
nsentences, ntokens = sample["nsentences"], sample["ntokens"]
# B x T
src_tokens, src_lengths = (
sample["net_input"]["src_tokens"],
sample["net_input"]["src_lengths"],
)
tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
losses, nll_loss = [], []
for obj in outputs:
if outputs[obj].get("loss", None) is None:
_losses = self._compute_loss(
outputs[obj].get("out"),
outputs[obj].get("tgt"),
outputs[obj].get("mask", None),
outputs[obj].get("ls", 0.0),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
else:
_losses = self._custom_loss(
outputs[obj].get("loss"),
name=obj + "-loss",
factor=outputs[obj].get("factor", 1.0),
)
losses += [_losses]
if outputs[obj].get("nll_loss", False):
nll_loss += [_losses.get("nll_loss", 0.0)]
loss = sum(l["loss"] for l in losses)
nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0)
# NOTE:
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
for l in losses:
logging_output[l["name"]] = (
utils.item(l["loss"].data / l["factor"])
if reduce
else l["loss"].data / l["factor"]
)
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs))
metrics.log_scalar(
"loss", loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
for key in logging_outputs[0]:
if key[-5:] == "-loss":
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(
key[:-5],
val / sample_size / math.log(2) if sample_size > 0 else 0.0,
sample_size,
round=3,
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
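
The label-smoothing mix in _compute_loss, isolated as a short sketch with toy tensors:

import torch
import torch.nn.functional as F

outputs = torch.randn(10, 32)           # (batch*len) x vocab
targets = torch.randint(0, 32, (10,))
label_smoothing = 0.1

logits = F.log_softmax(outputs, dim=-1)
nll_loss = F.nll_loss(logits, targets, reduction="none").float().mean()

# Mix the NLL with the mean log-probability over the whole vocabulary,
# which acts as a uniform-prior regularizer.
loss = nll_loss * (1 - label_smoothing) - logits.float().mean() * label_smoothing
print(float(loss))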

View File

@@ -1,141 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
@dataclass
class SentencePredictionConfig(FairseqDataclass):
classification_head_name: str = field(
default="sentence_classification_head",
metadata={"help": "name of the classification head to use"},
)
regression_target: bool = field(
default=False,
)
@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionCriterion(FairseqCriterion):
def __init__(self, cfg: SentencePredictionConfig, task):
super().__init__(task)
self.classification_head_name = cfg.classification_head_name
self.regression_target = cfg.regression_target
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
task_loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
task_loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {}
loss = task_loss
# mha & ffn regularization update
if (
hasattr(model.args, "mha_reg_scale_factor")
and model.args.mha_reg_scale_factor != 0.0
):
mha_reg_loss = model._get_adaptive_head_loss()
loss += mha_reg_loss
logging_output.update({"mha_reg_loss": mha_reg_loss})
if (
hasattr(model.args, "ffn_reg_scale_factor")
and model.args.ffn_reg_scale_factor != 0.0
):
ffn_reg_loss = model._get_adaptive_ffn_loss()
loss += ffn_reg_loss
logging_output.update({"ffn_reg_loss": ffn_reg_loss})
logging_output.update(
{
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
)
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs)
ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if mha_reg_loss_sum:
metrics.log_scalar(
"mha_reg_loss",
mha_reg_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if ffn_reg_loss_sum:
metrics.log_scalar(
"ffn_reg_loss",
ffn_reg_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
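
The two branches of the task loss above, reduced to a sketch with toy logits and targets:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)              # 4 sentences, 3 classes
targets = torch.tensor([0, 2, 1, 2])

# Classification branch: summed NLL computed in float32.
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
cls_loss = F.nll_loss(lprobs, targets, reduction="sum")

# Regression branch: summed MSE on the flattened float logits.
reg_logits = torch.randn(4).float()
reg_targets = torch.rand(4)
reg_loss = F.mse_loss(reg_logits, reg_targets, reduction="sum")
print(float(cls_loss), float(reg_loss))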

View File

@@ -1,63 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from fairseq.criterions import register_criterion
from fairseq.criterions.sentence_prediction import (
SentencePredictionCriterion,
SentencePredictionConfig,
)
@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig)
class SentencePredictionCriterionAdapters(SentencePredictionCriterion):
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.classification_head_name in model.classification_heads
), "model must provide sentence classification head for --criterion=sentence_prediction"
if "lang_id" not in sample:  # sample is a dict, so hasattr() would always be False here
# If no language ID is given, we fall back to English
lang_id = ["en_XX"] * sample["nsentences"]
else:
lang_id = sample["lang_id"]
logits, _ = model(
**sample["net_input"],
features_only=True,
classification_head_name=self.classification_head_name,
lang_id=lang_id,
)
targets = model.get_targets(sample, [logits]).view(-1)
sample_size = targets.numel()
if not self.regression_target:
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
logits = logits.view(-1).float()
targets = targets.float()
loss = F.mse_loss(logits, targets, reduction="sum")
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
return loss, sample_size, logging_output

View File

@@ -1,120 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("sentence_ranking")
class SentenceRankingCriterion(FairseqCriterion):
def __init__(self, task, ranking_head_name, save_predictions, num_classes):
super().__init__(task)
self.ranking_head_name = ranking_head_name
if save_predictions is not None:
self.prediction_h = open(save_predictions, "w")
else:
self.prediction_h = None
self.num_classes = num_classes
def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--save-predictions', metavar='FILE',
help='file to save predictions to')
parser.add_argument('--ranking-head-name',
default='sentence_classification_head',
help='name of the ranking head to use')
# fmt: on
def forward(self, model, sample, reduce=True):
"""Compute ranking loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
assert (
hasattr(model, "classification_heads")
and self.ranking_head_name in model.classification_heads
), "model must provide sentence ranking head for --criterion=sentence_ranking"
scores = []
for idx in range(self.num_classes):
score, _ = model(
**sample["net_input{idx}".format(idx=idx + 1)],
classification_head_name=self.ranking_head_name,
)
scores.append(score)
logits = torch.cat(scores, dim=1)
sample_size = logits.size(0)
if "target" in sample:
targets = model.get_targets(sample, [logits]).view(-1)
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
else:
targets = None
loss = torch.tensor(0.0, requires_grad=True)
if self.prediction_h is not None:
preds = logits.argmax(dim=1)
for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())):
if targets is not None:
label = targets[i].item()
print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
else:
print("{}\t{}".format(id, pred), file=self.prediction_h)
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample_size,
"sample_size": sample_size,
}
if targets is not None:
logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum()
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]:
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
metrics.log_scalar(
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
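
How the ranking loss falls out of per-candidate scores, sketched with toy values:

import torch
import torch.nn.functional as F

num_classes, bsz = 4, 2
# One score column per candidate, as the model returns once per
# "net_input{idx}" input; random scores stand in for model output.
scores = [torch.randn(bsz, 1) for _ in range(num_classes)]
logits = torch.cat(scores, dim=1)        # bsz x num_classes
targets = torch.tensor([1, 3])           # index of the correct candidate

lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
loss = F.nll_loss(lprobs, targets, reduction="sum")
print(float(loss), logits.argmax(dim=1))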

View File

@@ -1,516 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from collections import OrderedDict
import torch
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.ctc import CtcCriterion
from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import (
RdropLabelSmoothedCrossEntropyCriterion,
RdropLabelSmoothedCrossEntropyCriterionConfig,
duplicate_input,
)
from fairseq.criterions.tacotron2_loss import (
Tacotron2Criterion,
Tacotron2CriterionConfig,
)
logger = logging.getLogger(__name__)
class MultitaskCriterion:
def __init__(self, multitask_tasks, rdrop_alpha=0.0):
self.rdrop_alpha = rdrop_alpha
self.rdrop_alpha_mtl = rdrop_alpha
self.multitask_criterion = OrderedDict()
self.multitask_loss_weight = OrderedDict()
for task_name, task_obj in multitask_tasks.items():
if task_obj.args.get_loss_weight(0) == 0:
logger.info(f"Skip {task_name} loss criterion")
continue
rdrop_alpha_task = task_obj.args.rdrop_alpha
if rdrop_alpha_task is None:
rdrop_alpha_task = rdrop_alpha
self.rdrop_alpha_mtl = rdrop_alpha_task
logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}")
if task_obj.args.decoder_type == "ctc":
self.multitask_criterion[task_name] = CtcCriterion(
task_obj.args.criterion_cfg,
task_obj,
rdrop_alpha=rdrop_alpha_task,
)
else:
self.multitask_criterion[
task_name
] = RdropLabelSmoothedCrossEntropyCriterion(
task_obj,
task_obj.args.criterion_cfg.sentence_avg,
label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
rdrop_alpha=rdrop_alpha_task,
)
def set_multitask_loss_weight(self, task_name, weight=0.0):
self.multitask_loss_weight[task_name] = weight
def get_multitask_loss(self, model, sample, model_out):
logging_output = {}
loss = 0.0
for task_name, task_criterion in self.multitask_criterion.items():
layer_id = task_criterion.task.args.input_layer
if isinstance(task_criterion, CtcCriterion):
if task_criterion.task.args.input_from == "encoder":
if len(model_out["encoder_padding_mask"]) > 0:
non_padding_mask = ~model_out["encoder_padding_mask"][0]
input_lengths = non_padding_mask.long().sum(-1)
else:
out = model_out["encoder_states"][layer_id]
input_lengths = out.new_full(
(out.shape[1],), out.shape[0]
).long()
task_sample = {
"net_input": {
"src_tokens": model_out["encoder_states"][
layer_id
], # check batch idx
"src_lengths": input_lengths,
},
"id": sample["id"],
}
else:
task_sample = {
"net_input": {
"src_tokens": model_out["inner_states"][layer_id],
"src_lengths": sample["target_lengths"],
},
"id": sample["id"],
}
else:
task_sample = {
"net_input": {
"src_tokens": sample["multitask"][task_name]["net_input"][
"prev_output_tokens"
],
"encoder_out": {
"encoder_out": [model_out["encoder_states"][layer_id]],
"encoder_padding_mask": model_out["encoder_padding_mask"],
},
}
}
for key in ["target", "target_lengths", "ntokens"]:
task_sample[key] = sample["multitask"][task_name][key]
if task_name == getattr(model, "mt_task_name", None):
decoder_out = model_out["mt_decoder_out"]
else:
decoder_out = None
task_loss, task_sample_size, task_logging_output = task_criterion(
model.multitask_decoders[task_name], task_sample, net_output=decoder_out
)
loss = loss + self.multitask_loss_weight[task_name] * task_loss
task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name]
logging_output[task_name] = task_logging_output
return loss, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
for task_name in logging_outputs[0]["multitask"].keys():
# different criterion may return different logging
# currently only reduce on loss, the most common one
# ideally the way that losses are reduced should also depend on the task type
loss_sum = sum(
log["multitask"][task_name].get("loss", 0) for log in logging_outputs
)
sample_size = sum(
log["multitask"][task_name].get("sample_size", 0)
for log in logging_outputs
)
metrics.log_scalar(
f"multitask_{task_name}_loss",
loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
loss_weight = logging_outputs[0]["multitask"][task_name].get(
"loss_weight", 0
)
metrics.log_scalar(
f"multitask_{task_name}_loss_weight",
loss_weight,
weight=0,
priority=250,
)
@register_criterion(
"speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnitMultitaskTaskCriterion(
RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion
):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
rdrop_alpha,
)
MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha)
def forward(self, model, sample, reduce=True):
net_input_concat = {
"src_tokens": sample["net_input"]["src_tokens"],
"src_lengths": sample["net_input"]["src_lengths"],
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
"return_all_hiddens": True,
}
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
net_input_concat = duplicate_input(net_input_concat)
net_output, extra = model(**net_input_concat)
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, [net_output], sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, [net_output], sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
# inference metrics
if "targ_frames" in logging_outputs[0]:
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
if "multitask" not in logging_outputs[0]:
return
MultitaskCriterion.reduce_metrics(logging_outputs)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
@register_criterion(
"speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
rdrop_alpha,
)
def forward(self, model, sample, reduce=True):
net_input_concat = {
"src_tokens": sample["net_input"]["src_tokens"],
"src_lengths": sample["net_input"]["src_lengths"],
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
"prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
"net_input"
]["prev_output_tokens"],
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
"return_all_hiddens": True,
}
if getattr(model, "asr_task_name", None) is not None:
net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
model.asr_task_name
]["net_input"]["prev_output_tokens"]
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
net_input_concat = duplicate_input(net_input_concat)
net_output, extra = model(**net_input_concat)
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, [net_output], sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, [net_output], sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
)
MultitaskCriterion.__init__(self, task.multitask_tasks)
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
feat_out, eos_out, extra = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
tgt_speaker=sample["net_input"]["tgt_speaker"],
target_lengths=sample["target_lengths"],
return_all_hiddens=True,
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
sample["target_lengths"],
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(
extra["attn"],
sample["net_input"]["src_lengths"],
sample["target_lengths"],
reduction,
)
loss = (
l1_loss + mse_loss + eos_loss + attn_loss
) # do not include ctc loss as there's no text target
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
}
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
# inference metrics
if "targ_frames" in logging_outputs[0]:
n = sum(log.get("norm_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
if "multitask" not in logging_outputs[0]:
return
MultitaskCriterion.reduce_metrics(logging_outputs)
@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogram2passMultitaskTaskCriterion(
SpeechToSpectrogramMultitaskTaskCriterion
):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
)
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
feat_out, eos_out, extra = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
"prev_output_tokens"
],
tgt_speaker=sample["net_input"]["tgt_speaker"],
target_lengths=sample["target_lengths"],
return_all_hiddens=True,
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
sample["target_lengths"],
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(
extra["attn"],
sample["net_input"]["src_lengths"],
sample["target_lengths"],
reduction,
)
loss = (
l1_loss + mse_loss + eos_loss + attn_loss
) # do not include ctc loss as there's no text target
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
}
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
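
Stripped of the sample plumbing, the multitask accumulation in get_multitask_loss reduces to a weighted sum; the task names, losses, and weights below are hypothetical:

import torch

task_losses = {"source_letter": torch.tensor(2.3), "target_letter": torch.tensor(1.7)}
multitask_loss_weight = {"source_letter": 8.0, "target_letter": 8.0}

# The main loss accumulates weight * task_loss for each auxiliary decoder.
loss = torch.tensor(0.0)
for name, task_loss in task_losses.items():
    loss = loss + multitask_loss_weight[name] * task_loss
print(float(loss))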

View File

@@ -1,126 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from dataclasses import dataclass, field
import torch.nn.functional as F
from fairseq import metrics
from fairseq.tasks import FairseqTask
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class SpeechUnitLmCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
loss_weights: str = field(
default="1.;0.0;0.0",
metadata={
"help": "Weights of the losses that correspond to token, duration, and F0 streams"
},
)
discrete_duration: bool = II("task.discrete_duration")
discrete_f0: bool = II("task.discrete_f0")
def mae_loss(pred, targ, mask, reduce=True):
if pred.ndim == 3:
pred = pred.squeeze(2)
else:
assert pred.ndim == 2
loss = (pred.float() - targ.float()).abs() * (~mask).float()
loss = loss.sum() if reduce else loss.view(-1)
return loss
def nll_loss(pred, targ, mask, reduce=True):
lprob = F.log_softmax(pred, dim=-1)
loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none")
loss = loss * (~mask).float().view(-1)
loss = loss.sum() if reduce else loss.view(-1)
return loss
@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig)
class SpeechUnitLmCriterion(FairseqCriterion):
def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask):
super().__init__(task)
self.sentence_avg = cfg.sentence_avg
self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")])
assert self.weights.size(0) == 3
assert (self.weights >= 0.0).all()
self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss
self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
token_loss = nll_loss(
net_output["token"], sample["target"], sample["mask"], reduce
)
dur_loss = self.dur_loss_fn(
net_output["duration"],
sample["dur_target"],
sample["dur_mask"],
reduce,
)
f0_loss = self.f0_loss_fn(
net_output["f0"],
sample["f0_target"],
sample["f0_mask"],
reduce,
)
loss = self.weights.to(token_loss.device) * torch.stack(
[token_loss, dur_loss, f0_loss], dim=-1
)
loss = loss.sum() if reduce else loss.sum(-1)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.detach().sum().item(),
"token_loss": token_loss.detach().sum().item(),
"dur_loss": dur_loss.detach().sum().item(),
"f0_loss": f0_loss.detach().sum().item(),
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs)
dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs)
f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar(
"token_loss", token_loss_sum / sample_size, sample_size, round=3
)
metrics.log_scalar("dur_loss", dur_loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar("f0_loss", f0_loss_sum / sample_size, sample_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return True
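
The three-stream weighting in forward(), as a sketch; the per-stream sums are toy numbers and the weight string is the config default:

import torch

token_loss = torch.tensor(3.1)
dur_loss = torch.tensor(0.8)
f0_loss = torch.tensor(0.5)

# "1.;0.0;0.0" -> train on the token stream only by default.
weights = torch.tensor([float(w) for w in "1.;0.0;0.0".split(";")])
loss = (weights * torch.stack([token_loss, dur_loss, f0_loss], dim=-1)).sum()
print(float(loss))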

View File

@@ -1,226 +0,0 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, Dict, List
import torch
import torch.nn.functional as F
from omegaconf import II
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.data.data_utils import lengths_to_mask
from fairseq.dataclass import FairseqDataclass
logger = logging.getLogger(__name__)
@dataclass
class Tacotron2CriterionConfig(FairseqDataclass):
bce_pos_weight: float = field(
default=1.0,
metadata={"help": "weight of positive examples for BCE loss"},
)
use_guided_attention_loss: bool = field(
default=False,
metadata={"help": "use guided attention loss"},
)
guided_attention_loss_sigma: float = field(
default=0.4,
metadata={"help": "weight of positive examples for BCE loss"},
)
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
sentence_avg: bool = II("optimization.sentence_avg")
class GuidedAttentionLoss(torch.nn.Module):
"""
Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
Networks with Guided Attention (https://arxiv.org/abs/1710.08969)
"""
def __init__(self, sigma):
super().__init__()
self.sigma = sigma
@staticmethod
@lru_cache(maxsize=8)
def _get_weight(s_len, t_len, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len))
grid_x = grid_x.to(s_len.device)
grid_y = grid_y.to(s_len.device)
w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
return 1.0 - torch.exp(-w / (2 * (sigma**2)))
def _get_weights(self, src_lens, tgt_lens):
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
weights = torch.zeros((bsz, max_t_len, max_s_len))
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
return weights
@staticmethod
def _get_masks(src_lens, tgt_lens):
in_masks = lengths_to_mask(src_lens)
out_masks = lengths_to_mask(tgt_lens)
return out_masks.unsqueeze(2) & in_masks.unsqueeze(1)
def forward(self, attn, src_lens, tgt_lens, reduction="mean"):
weights = self._get_weights(src_lens, tgt_lens).to(attn.device)
masks = self._get_masks(src_lens, tgt_lens).to(attn.device)
loss = (weights * attn.transpose(1, 2)).masked_select(masks)
loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss)
return loss
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.bce_pos_weight = bce_pos_weight
self.guided_attn = None
if use_guided_attention_loss:
self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma)
self.ctc_weight = ctc_weight
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
src_tokens = sample["net_input"]["src_tokens"]
src_lens = sample["net_input"]["src_lengths"]
tgt_lens = sample["target_lengths"]
feat_out, eos_out, extra = model(
src_tokens=src_tokens,
src_lengths=src_lens,
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
net_output = (feat_out, eos_out, extra)
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
"ctc_loss": utils.item(ctc_loss.data),
}
return loss, sample_size, logging_output
def compute_loss(
self,
feat_out,
feat_out_post,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction="mean",
):
mask = lengths_to_mask(tgt_lens)
_eos_out = eos_out[mask].squeeze()
_eos_tgt = eos_tgt[mask]
_feat_tgt = feat_tgt[mask]
_feat_out = feat_out[mask]
_feat_out_post = feat_out_post[mask]
l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
eos_loss = F.binary_cross_entropy_with_logits(
_eos_out,
_eos_tgt,
pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction,
)
return l1_loss, mse_loss, eos_loss
@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
ns = [log.get("sample_size", 0) for log in logging_outputs]
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
metrics.log_scalar(key, val, ntot, round=3)
metrics.log_scalar("sample_size", ntot, len(logging_outputs))
# inference metrics
if "targ_frames" not in logging_outputs[0]:
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
return False
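
The guided-attention penalty above in closed form: w(x, y) = 1 - exp(-(y/S - x/T)^2 / (2*sigma^2)), cheap near the diagonal and expensive far from it. A sketch with plain int lengths (the class itself caches per-length weights and works on tensors):

import torch

def guided_attention_weight(s_len, t_len, sigma=0.4):
    grid_x, grid_y = torch.meshgrid(
        torch.arange(t_len), torch.arange(s_len), indexing="ij"
    )
    w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2
    return 1.0 - torch.exp(-w / (2 * sigma**2))

print(guided_attention_weight(5, 7).shape)  # torch.Size([7, 5]): tgt x src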

View File

@@ -1,230 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.logging.meters import safe_round
from fairseq.utils import is_xla_tensor
@dataclass
class Wav2VecCriterionConfig(FairseqDataclass):
infonce: bool = field(
default=False,
metadata={
"help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)"
},
)
loss_weights: Optional[List[float]] = field(
default=None,
metadata={"help": "weights for additional loss terms (not first one)"},
)
log_keys: List[str] = field(
default_factory=lambda: [],
metadata={"help": "output keys to log"},
)
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
super().__init__(task)
self.infonce = infonce
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
self.xla = is_xla_tensor(logits)
# XXX: handle weights on xla.
weights = None
if hasattr(model, "get_target_weights") and not self.infonce:
weights = model.get_target_weights(target, net_output)
if torch.is_tensor(weights):
weights = weights.float()
losses = []
reduction = "none" if ((not reduce) or self.xla) else "sum"
if self.infonce:
loss = F.cross_entropy(logits, target, reduction=reduction)
else:
loss = F.binary_cross_entropy_with_logits(
logits, target.float(), weights, reduction=reduction
)
if self.xla:
# tpu-comment: since dynamic shapes lead to recompilations on xla,
# we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss.
mi = (
sample["net_input"]["mask_indices"]
.transpose(0, 1) # logits are transposed in `model.get_logits`
.reshape(logits.size(0))
)
loss = (loss * mi).sum() if reduce else (loss * mi)
if "sample_size" in sample:
sample_size = sample["sample_size"]
elif "mask_indices" in sample["net_input"]:
sample_size = sample["net_input"]["mask_indices"].sum()
else:
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, coef in zip(extra_losses, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
losses.append(p)
logging_output = {
"loss": loss.item() if (reduce and not self.xla) else loss.detach(),
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
}
for lk in self.log_keys:
# Only store "logits" and "target" for computing MAP and MAUC
# during validation
if lk == "logits":
if not self.training:
logging_output["logits"] = logits.cpu().numpy()
elif lk == "target":
if not self.training:
# If the targets have been mixed with the predictions of
# teacher models, find the original targets
if hasattr(model, "get_original_targets"):
original_target = model.get_original_targets(sample, net_output)
else:
original_target = target
logging_output["target"] = original_target.cpu().numpy()
elif lk in net_output:
value = net_output[lk]
if not is_xla_tensor(value):
value = float(value)
logging_output[lk] = value
if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach()
if self.infonce:
with torch.no_grad():
if logits.numel() == 0:
corr = 0
count = 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
if is_xla_tensor(logits):
max, min = max * mi, min * mi
both = max & min
corr = max.long().sum() - both.long().sum()
count = mi.sum()
else:
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = float(max.numel())
logging_output["correct"] = corr
logging_output["count"] = count
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
correct = sum(log.get("correct", 0) for log in logging_outputs)
metrics.log_scalar("_correct", correct)
total = sum(log.get("count", 0) for log in logging_outputs)
metrics.log_scalar("_total", total)
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: safe_round(
meters["_correct"].sum / meters["_total"].sum, 5
)
if meters["_total"].sum > 0
else float("nan"),
)
builtin_keys = {
"loss",
"ntokens",
"nsentences",
"sample_size",
"correct",
"count",
}
for k in logging_outputs[0]:
if k not in builtin_keys:
val = sum(log.get(k, 0) for log in logging_outputs)
if k.startswith("loss"):
metrics.log_scalar(
k, val / (sample_size or 1) / math.log(2), sample_size, round=3
)
else:
metrics.log_scalar(k, val / len(logging_outputs), round=3)
# FIXME: revert when gather based xla reduction is implemented
# @staticmethod
# def logging_outputs_can_be_summed() -> bool:
def logging_outputs_can_be_summed(self) -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
# XXX: Gather based reduction not implemented for xla yet.
# So we fall to sum based reduction for xla.
return self.xla
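
The InfoNCE branch is ordinary cross entropy once the logits are laid out with the positive in column 0; a sketch with toy numbers:

import torch
import torch.nn.functional as F

logits = torch.randn(16, 1 + 10)            # 16 masked frames, 10 distractors
target = torch.zeros(16, dtype=torch.long)  # the positive is always index 0

loss = F.cross_entropy(logits, target, reduction="sum")
accuracy = (logits.argmax(-1) == 0).float().mean()  # the "correct" count above
print(float(loss), float(accuracy))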

View File

@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import criterions, models, modules # noqa

View File

@@ -1,14 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("fairseq.model_parallel.criterions." + module)

View File

@@ -1,87 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
try:
from fairseq.model_parallel.megatron.mpu.cross_entropy import (
vocab_parallel_cross_entropy,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
target = sample["target"]
loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
loss = (loss * (target != self.padding_idx)).sum()
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
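# --- Illustrative sketch (not part of fairseq): the "loss" aggregated above
# is summed negative log-likelihood in nats; dividing by sample_size and
# ln(2) gives bits per token, and perplexity is 2**bits. Numbers made up.
import math

loss_sum, sample_size = 693.1, 500.0                   # hypothetical totals
bits_per_token = loss_sum / sample_size / math.log(2)  # ~2.0 bits
ppl = 2 ** bits_per_token                              # ~4.0, == exp(nats per token)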

View File

@@ -1,75 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer
try:
from fairseq.model_parallel.megatron.mpu import (
get_data_parallel_rank,
get_data_parallel_world_size,
get_model_parallel_src_rank,
get_cuda_rng_tracker,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class MegatronTrainer(Trainer):
"""Main class for model parallel with data parallel training."""
def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
super().__init__(cfg, task, model, criterion, **kwargs)
def clip_grad_norm(self, clip_norm):
def _aggregate_model_parallel_grad_norm(total_norm):
total_norm = total_norm**2
distributed_utils.all_reduce(
total_norm, group=distributed_utils.get_model_parallel_group()
)
total_norm = total_norm**0.5
return total_norm
return self.optimizer.clip_grad_norm(
clip_norm,
aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
super().save_checkpoint(filename, extra_state)
def load_checkpoint(
self,
filename,
reset_optimizer=False,
reset_lr_scheduler=False,
optimizer_overrides=None,
reset_meters=False,
):
extra_state = super().load_checkpoint(
filename,
reset_optimizer=reset_optimizer,
reset_lr_scheduler=reset_lr_scheduler,
optimizer_overrides=optimizer_overrides,
reset_meters=reset_meters,
)
if extra_state is not None and "rng_tracker_states" in extra_state:
get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
return extra_state
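# --- Illustrative sketch (not part of fairseq): clip_grad_norm above merges
# per-shard gradient norms by all-reducing the *squared* norms and taking a
# square root, since ||g||^2 = sum_i ||g_i||^2 across model-parallel shards.
# Single-process stand-in with made-up shard norms:
shard_norms = [3.0, 4.0]                              # hypothetical per-shard norms
total_norm = sum(n ** 2 for n in shard_norms) ** 0.5  # 5.0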

View File

@@ -1,20 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
model_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("fairseq.model_parallel.models." + model_name)

View File

@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa

View File

@@ -1,600 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
)
EncoderOut = namedtuple(
"TransformerEncoderOut",
[
"encoder_out", # T x B x C
"encoder_padding_mask", # B x T
"encoder_embedding", # B x T x C
"encoder_states", # List[T x B x C]
],
)
class TransformerEncoderEmbedding(nn.Module):
"""Encoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
if isinstance(embed_tokens, nn.ModuleList):
self.padding_idx = embed_tokens[0].padding_idx
embed_dim = sum(e.embedding_dim for e in embed_tokens)
else:
self.padding_idx = embed_tokens.padding_idx
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
args.max_source_positions,
embed_dim,
self.padding_idx,
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
def forward(self, input):
# embed tokens and positions
src_tokens = input[0]
prev_output_tokens = input[2]
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(src_tokens))
embedded = torch.cat(x_embed_list, dim=-1)
else:
embedded = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * embedded
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding:
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
return (x, encoder_padding_mask, prev_output_tokens)
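# --- Illustrative sketch (not part of fairseq): the embedding stage maps a
# B x T batch of token ids to a T x B x C tensor plus a B x T padding mask.
# All sizes here are made up.
import torch
import torch.nn as nn

B, T, C, pad = 2, 5, 8, 0
src = torch.randint(1, 100, (B, T))
src[0, -2:] = pad                      # pad the tail of the first sentence
emb = nn.Embedding(100, C, padding_idx=pad)
x = emb(src).transpose(0, 1)           # B x T x C -> T x B x C
mask = src.eq(pad)                     # True at padding positions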
class TransformerEncoderLayerNorm(nn.Module):
"""
Layer norm at the end of all encoder layers if
args.encoder_normalize_before = True
"""
def __init__(self, args, embed_dim):
super().__init__()
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input):
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
if self.layer_norm:
x = self.layer_norm(x)
# keeping track of the incremental_state is not supported yet
return (x, encoder_padding_mask, prev_output_tokens)
class TransformerDecoderEmbedding(nn.Module):
"""Decoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens):
super().__init__()
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = (
sum(e.embedding_dim for e in embed_tokens)
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.embedding_dim
)
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
padding_idx = (
embed_tokens[0].padding_idx
if isinstance(embed_tokens, nn.ModuleList)
else embed_tokens.padding_idx
)
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.max_target_positions,
embed_dim,
padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
def forward(self, input):
mt_task = False
if isinstance(input, tuple):
if len(input) == 3:
encoder_out = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
incremental_state = None # Hardcoding to avoid passing of None objects
mt_task = True
else:
# HACK for now, need to fix (TODO sidgoyal)
prev_output_tokens = input[0]
# discard "src_lengths"
encoder_out = None
encoder_padding_mask = None
incremental_state = None
else:
prev_output_tokens = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
if isinstance(self.embed_tokens, nn.ModuleList):
x_embed_list = []
for embed_tokens_part in self.embed_tokens:
x_embed_list.append(embed_tokens_part(prev_output_tokens))
x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
else:
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
class TransformerDecoderOutputLayer(nn.Module):
def __init__(self, args, embed_tokens, dictionary):
super().__init__()
self.share_input_output_embed = args.share_decoder_input_output_embed
self.embed_tokens = embed_tokens
self.output_embed_dim = args.decoder_output_dim
embed_dim = args.decoder_embed_dim
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
if args.adaptive_softmax_cutoff is not None:
assert not isinstance(embed_tokens, nn.ModuleList)
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_tokens = nn.Parameter(
torch.Tensor(len(dictionary), self.output_embed_dim)
)
nn.init.normal_(
self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5
)
if args.decoder_normalize_before and not getattr(
args, "no_decoder_final_norm", False
):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(self, input, apply_final_proj=True):
if isinstance(input, tuple):
x = input[0]
else:
x = input
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
if apply_final_proj:
x = self.output_layer(x)
return x
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
if isinstance(self.embed_tokens, nn.ModuleList):
output = None
for i, emb in enumerate(self.embed_tokens):
sidx = i * emb.embedding_dim
eidx = (i + 1) * emb.embedding_dim
if output is None:
output = F.linear(features[:, :, sidx:eidx], emb.weight)
else:
output += F.linear(features[:, :, sidx:eidx], emb.weight)
return output
else:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_tokens)
else:
return features
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def upgrade_state_dict_named(self, state_dict, name):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layer_norms.{}.{}".format(name, old, m)
if k in state_dict:
state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
del state_dict[k]
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
input[2] (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(src_len, batch, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
x = input[0]
encoder_padding_mask = input[1]
prev_output_tokens = input[2]
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return (x, encoder_padding_mask, prev_output_tokens)
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
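# --- Illustrative sketch (not part of fairseq): the XOR in maybe_layer_norm
# selects pre-norm vs post-norm. The before=True call fires only when
# normalize_before is set; the after=True call only when it is not.
normalize_before = True                       # hypothetical pre-norm config
fires_before_call = False ^ normalize_before  # after=False -> True
fires_after_call = True ^ normalize_before    # after=True  -> False
assert fires_before_call and not fires_after_call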
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
)
self.activation_dropout = getattr(args, "activation_dropout", 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, "relu_dropout", 0)
self.normalize_before = args.decoder_normalize_before
# use LayerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determine this.
# TODO remove this once we update apex with the fix
export = getattr(args, "char_inputs", False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def forward(self, input):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
input[2] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
Returns:
output (Tuple):
output[0] (Tensor): decoded output of shape `(tgt_len, batch, embed_dim)`
output[1] (Tensor): encoder output
output[2] (ByteTensor/FloatTensor): encoder padding mask
"""
# Note: incremental state is not yet supported
mt_task = False
if isinstance(input, tuple):
x = input[0]
encoder_out = input[1]
encoder_padding_mask = input[2]
incremental_state = None
mt_task = True
else:
x = input
encoder_out = None
encoder_padding_mask = None
incremental_state = None
if incremental_state is None:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
# TODO: add back prev_self_attn_state, prev_attn_state,
# self_attn_padding_mask
prev_self_attn_state = None
prev_attn_state = None
self_attn_padding_mask = None
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.self_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if mt_task:
return (x, encoder_out, encoder_padding_mask)
return x
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return layer_norm(x)
else:
return x
def make_generation_fast_(self, need_attn=False, **kwargs):
self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
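# --- Illustrative sketch (not part of fairseq): buffered_future_mask above
# caches a strictly upper-triangular -inf mask so position t can attend only
# to positions <= t. Small made-up size:
import torch

dim = 4
future_mask = torch.triu(torch.full((dim, dim), float("-inf")), 1)
# row t holds -inf in columns t+1..dim-1 and 0.0 elsewhere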

View File

@@ -1,789 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
Embedding,
TransformerDecoderEmbedding,
TransformerDecoderLayer,
TransformerDecoderOutputLayer,
TransformerEncoderEmbedding,
TransformerEncoderLayer,
TransformerEncoderLayerNorm,
)
from fairseq.models import (
BaseFairseqModel,
FairseqDecoder,
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
base_architecture,
transformer_iwslt_de_en,
transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding
logger = logging.getLogger(__name__)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False
def import_pipe():
global TORCH_PIPE
global RPC_INIT
try:
from torch.distributed.pipeline.sync import Pipe # noqa
global Pipe
from torch.distributed.pipeline.sync.utils import partition_model
global partition_model
from torch.distributed import rpc
import tempfile
TORCH_PIPE = True
# Initialize a single-process RPC agent: TORCH_PIPE requires RRef,
# and RRef requires RPC to be initialized, so we bring up RPC with a
# single node.
tmpfile = tempfile.NamedTemporaryFile()
if not RPC_INIT:
rpc.init_rpc(
name="worker",
rank=0,
world_size=1,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
init_method="file://{}".format(tmpfile.name),
),
)
RPC_INIT = True
logger.info("Using torch pipe")
except ImportError:
try:
from fairscale.nn import Pipe # noqa
logger.info("Using fairscale pipe")
except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale")
@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
import_pipe()
super().__init__()
assert isinstance(encoder, FairseqEncoder)
assert isinstance(decoder, FairseqDecoder)
encoder_module_list = (
[encoder.embedding_layer]
+ list(encoder.encoder_layers)
+ [encoder.final_layer_norm]
)
self.num_encoder_modules = len(encoder_module_list)
decoder_module_list = (
[decoder.embedding_layer]
+ list(decoder.decoder_layers)
+ [decoder.decoder_output_layer]
)
self.num_decoder_modules = len(decoder_module_list)
module_list = encoder_module_list + decoder_module_list
self.devices = devices
if TORCH_PIPE:
self.model = Pipe(
partition_model(nn.Sequential(*module_list), balance, devices),
chunks=chunks,
checkpoint=checkpoint,
)
else:
self.model = Pipe(
nn.Sequential(*module_list),
balance=balance,
devices=devices,
chunks=chunks,
checkpoint=checkpoint,
)
self.encoder_max_positions = self.max_positions_helper(
encoder.embedding_layer, "max_source_positions"
)
self.decoder_max_positions = self.max_positions_helper(
decoder.embedding_layer, "max_target_positions"
)
self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
# Note: To be populated during inference
self.encoder = None
self.decoder = None
def forward(self, src_tokens, src_lengths, prev_output_tokens):
if self.training:
input_lst = [src_tokens, src_lengths, prev_output_tokens]
input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
if TORCH_PIPE:
return self.model(input).local_value()
else:
return self.model(input)
else:
assert self.encoder is not None and self.decoder is not None, (
"encoder and decoder need to be initialized by "
+ "calling the `prepare_for_inference_()` method"
)
# `input` is only defined on the training path; run the encoder and
# decoder populated by prepare_for_inference_() directly.
encoder_output_tuple = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_output_tuple)
def prepare_for_inference_(self, cfg):
if self.encoder is not None and self.decoder is not None:
logger.info("Encoder and Decoder already initialized")
return
encoder_module_list = []
decoder_module_list = []
module_count = 0
for partition in self.model.partitions:
for module in partition:
if module_count < self.num_encoder_modules:
encoder_module_list.append(module)
else:
decoder_module_list.append(module)
module_count += 1
self.model = None
self.encoder = TransformerEncoder(
cfg.distributed_training, None, None, encoder_module_list
)
self.decoder = TransformerDecoder(
cfg.distributed_training,
None,
None,
decoder_module_list=decoder_module_list,
)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
help='Number of embedding layer chunks (enables more even distribution '
'of optimizer states across data parallel nodes '
'when using optimizer state sharding and '
'a big embedding vocabulary)')
# fmt: on
@classmethod
def build_model_base(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, "max_source_positions"):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, "max_target_positions"):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
assert embed_dim % num_embed_chunks == 0, (
f"Number of embedding chunks = {num_embed_chunks} should be "
+ f"divisible by the embedding dimension = {embed_dim}"
)
assert path is None or num_embed_chunks == 1, (
"Loading embedding from a path with number of embedding chunks > 1"
+ " is not yet supported"
)
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
# if provided, load from preloaded dictionaries
if path:
emb = Embedding(num_embeddings, embed_dim, padding_idx)
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
else:
embed_chunk_dim = embed_dim // num_embed_chunks
emb = nn.ModuleList()
for i in range(num_embed_chunks):
emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
return emb
num_embed_chunks = args.num_embedding_chunks
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens = build_embedding(
src_dict,
args.encoder_embed_dim,
args.encoder_embed_path,
num_embed_chunks,
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
"Not sharing decoder I/O embeddings is not yet supported with number of "
+ "embedding chunks > 1"
)
encoder_embed_tokens = build_embedding(
src_dict,
args.encoder_embed_dim,
args.encoder_embed_path,
num_embed_chunks,
)
decoder_embed_tokens = build_embedding(
tgt_dict,
args.decoder_embed_dim,
args.decoder_embed_path,
num_embed_chunks,
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return (encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(args, tgt_dict, embed_tokens)
@classmethod
def build_model(cls, args, task):
encoder, decoder = cls.build_model_base(args, task)
return PipelineParallelTransformerModel(
encoder=encoder,
decoder=decoder,
balance=utils.eval_str_list(args.pipeline_balance, type=int),
devices=utils.eval_str_list(args.pipeline_devices, type=int),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
def output_layer(self, features, **kwargs):
"""Project features to the default output size (typically vocabulary size)."""
return self.decoder.output_layer(features, **kwargs)
def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder_max_positions, self.decoder_max_positions)
def max_positions_helper(
self, embedding_layer, max_positions_field="max_source_positions"
):
"""Maximum input length supported by the encoder or decoder."""
if embedding_layer.embed_positions is None:
return getattr(embedding_layer, max_positions_field)
return min(
getattr(embedding_layer, max_positions_field),
embedding_layer.embed_positions.max_positions,
)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
assert "target" in sample
target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output, target=target)
return out.exp_() if not log_probs else out
# A Pipe() module returns a tuple of tensors as the output.
# In this case, the tuple has one element - the output tensor of logits
logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
if log_probs:
return utils.log_softmax(logits, dim=-1, onnx_trace=False)
else:
return utils.softmax(logits, dim=-1, onnx_trace=False)
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder_max_positions
def load_state_dict(self, state_dict, strict=True, model_cfg=None):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
is_regular_transformer = not any("model.partitions" in k for k in state_dict)
if is_regular_transformer:
state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
return super().load_state_dict(state_dict, strict)
def convert_to_pipeline_parallel_state_dict(self, state_dict):
new_state_dict = self.state_dict()
encoder_layer_idx = 0
decoder_layer_idx = 0
encoder_key_suffixes = [
"self_attn.k_proj.weight",
"self_attn.k_proj.bias",
"self_attn.v_proj.weight",
"self_attn.v_proj.bias",
"self_attn.q_proj.weight",
"self_attn.q_proj.bias",
"self_attn.out_proj.weight",
"self_attn.out_proj.bias",
"self_attn_layer_norm.weight",
"self_attn_layer_norm.bias",
"fc1.weight",
"fc1.bias",
"fc2.weight",
"fc2.bias",
"final_layer_norm.weight",
"final_layer_norm.bias",
]
decoder_key_suffixes = [
"self_attn.k_proj.weight",
"self_attn.k_proj.bias",
"self_attn.v_proj.weight",
"self_attn.v_proj.bias",
"self_attn.q_proj.weight",
"self_attn.q_proj.bias",
"self_attn.out_proj.weight",
"self_attn.out_proj.bias",
"self_attn_layer_norm.weight",
"self_attn_layer_norm.bias",
"encoder_attn.k_proj.weight",
"encoder_attn.k_proj.bias",
"encoder_attn.v_proj.weight",
"encoder_attn.v_proj.bias",
"encoder_attn.q_proj.weight",
"encoder_attn.q_proj.bias",
"encoder_attn.out_proj.weight",
"encoder_attn.out_proj.bias",
"encoder_attn_layer_norm.weight",
"encoder_attn_layer_norm.bias",
"fc1.weight",
"fc1.bias",
"fc2.weight",
"fc2.bias",
"final_layer_norm.weight",
"final_layer_norm.bias",
]
for pid, partition in enumerate(self.model.partitions):
logger.info(f"Begin Partition {pid}")
for mid, module in enumerate(partition):
# fmt: off
if isinstance(module, TransformerEncoderEmbedding):
new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
if isinstance(module, TransformerEncoderLayer):
for suffix in encoder_key_suffixes:
new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
encoder_layer_idx += 1
if isinstance(module, TransformerDecoderLayer):
for suffix in decoder_key_suffixes:
new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
decoder_layer_idx += 1
if isinstance(module, TransformerEncoderLayerNorm):
if 'encoder.layer_norm.weight' in state_dict:
new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
if isinstance(module, TransformerDecoderEmbedding):
new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
if isinstance(module, TransformerDecoderOutputLayer):
new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
# fmt: on
return new_state_dict
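# --- Illustrative sketch (not part of fairseq): the conversion above renames
# flat "encoder.layers.N.<suffix>" checkpoint keys into the partitioned
# "model.partitions.P.M.<suffix>" layout. A hypothetical single mapping:
old_key = "encoder.layers.0.self_attn.q_proj.weight"
pid, mid, suffix = 0, 1, "self_attn.q_proj.weight"
new_key = f"model.partitions.{pid}.{mid}.{suffix}"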
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
import_pipe()
self.use_pipeline = encoder_module_list is not None
if not self.use_pipeline:
self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
self.encoder_layers = nn.Sequential(
*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
)
if isinstance(embed_tokens, nn.ModuleList):
emb_dim = sum(e.embedding_dim for e in embed_tokens)
else:
emb_dim = embed_tokens.embedding_dim
self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
else:
encoder_balance = utils.eval_str_list(
args.pipeline_encoder_balance, type=int
)
encoder_devices = utils.eval_str_list(
args.pipeline_encoder_devices, type=int
)
assert sum(encoder_balance) == len(encoder_module_list), (
f"Sum of encoder_balance={encoder_balance} is not equal "
+ f"to num_encoder_modules={len(encoder_module_list)}"
)
if TORCH_PIPE:
self.model = Pipe(
module=partition_model(
nn.Sequential(*encoder_module_list),
encoder_balance,
encoder_devices,
),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
else:
self.model = Pipe(
module=nn.Sequential(*encoder_module_list),
balance=encoder_balance,
devices=encoder_devices,
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
def forward(self, src_tokens, src_lengths):
"""
Args:
input_tuple(
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
)
Returns:
output_tuple(
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- prev_output_tokens
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
)
"""
dummy_prev_output_tokens = torch.zeros(
1, dtype=src_tokens.dtype, device=src_tokens.device
)
input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
if TORCH_PIPE:
encoder_out = self.model(input_tuple).local_value()
else:
encoder_out = self.model(input_tuple)
else:
encoder_embed_output_tuple = self.embedding_layer(input_tuple)
encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
encoder_out = self.final_layer_norm(encoder_layers_output)
# first element is the encoder output
# second element is the encoder padding mask
# the remaining elements of EncoderOut are not computed by
# the PipelineParallelTransformer
return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out.encoder_out is not None:
encoder_out = encoder_out._replace(
encoder_out=encoder_out.encoder_out.index_select(1, new_order)
)
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
0, new_order
)
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(
0, new_order
)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_source_positions
return min(
self.embedding_layer.max_source_positions,
self.embedding_layer.embed_positions.max_positions,
)
class TransformerDecoder(FairseqDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
decoder_module_list=None,
):
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
import_pipe()
self.use_pipeline = decoder_module_list is not None
if not self.use_pipeline:
self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
self.decoder_layers = nn.Sequential(
*[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
self.decoder_output_layer = TransformerDecoderOutputLayer(
args, embed_tokens, dictionary
)
else:
decoder_balance = utils.eval_str_list(
args.pipeline_decoder_balance, type=int
)
decoder_devices = utils.eval_str_list(
args.pipeline_decoder_devices, type=int
)
assert sum(decoder_balance) == len(decoder_module_list), (
f"Sum of decoder_balance={decoder_balance} is not equal "
+ f"to num_decoder_modules={len(decoder_module_list)}"
)
if TORCH_PIPE:
self.model = Pipe(
module=partition_model(
nn.Sequential(*decoder_module_list),
decoder_balance,
decoder_devices,
),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
else:
self.model = Pipe(
module=nn.Sequential(*decoder_module_list),
balance=decoder_balance,
devices=decoder_devices,
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
def forward(
self,
prev_output_tokens,
encoder_out=None,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
input_tuple = (
encoder_out.encoder_out,
encoder_out.encoder_padding_mask,
prev_output_tokens,
)
if self.use_pipeline:
input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
if TORCH_PIPE:
return (self.model(input_tuple).local_value(),)
else:
return (self.model(input_tuple),)
else:
embed_layer_output = self.embedding_layer(input_tuple)
state = self.decoder_layers(embed_layer_output)
return (self.decoder_output_layer(state),)
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_out)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embedding_layer.embed_positions is None:
return self.embedding_layer.max_target_positions
return min(
self.embedding_layer.max_target_positions,
self.embedding_layer.embed_positions.max_positions,
)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, "_future_mask")
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
@register_model_architecture(
"pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
transformer_iwslt_de_en(args)
@register_model_architecture(
"pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
transformer_wmt_en_de_big(args)
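# --- Illustrative sketch (not part of fairseq): with --num-embedding-chunks,
# build_embedding above replaces one embed_dim-wide table with N narrower
# tables whose outputs are concatenated back to embed_dim (see
# TransformerEncoderEmbedding.forward). Made-up sizes:
import torch
import torch.nn as nn

vocab, embed_dim, chunks = 100, 8, 2
emb = nn.ModuleList(
    nn.Embedding(vocab, embed_dim // chunks, padding_idx=0)
    for _ in range(chunks)
)
tokens = torch.randint(1, vocab, (2, 5))
x = torch.cat([e(tokens) for e in emb], dim=-1)
assert x.shape == (2, 5, embed_dim)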

View File

@@ -1,6 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa

View File

@@ -1,225 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
roberta_base_architecture,
roberta_prenorm_architecture,
RobertaEncoder,
RobertaModel,
)
from fairseq.modules import LayerNorm
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
ColumnParallelLinear,
VocabParallelEmbedding,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
logger = logging.getLogger(__name__)
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
def __init__(self, args, encoder):
super().__init__(args, encoder)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
RobertaModel.add_args(parser)
parser.add_argument(
"--no-final-layer-norm",
action="store_true",
help=(
"don't add final layernorm (only applicable when "
"--encoder-normalize-before=True"
),
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture(args)
task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
if not hasattr(args, "max_positions"):
args.max_positions = args.tokens_per_sample
if getattr(args, "untie_weights_roberta", False):
raise NotImplementedError(
"--untie-weights-roberta is not supported in model parallel mode"
)
encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
return cls(args, encoder)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs
):
if classification_head_name is not None:
features_only = True
x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x, extra
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = ModelParallelRobertaClassificationHead(
self.args.encoder_embed_dim,
inner_dim or self.args.encoder_embed_dim,
num_classes,
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
class ModelParallelRobertaLMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the masked tokens while training;
# this saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
x = copy_to_model_parallel_region(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight)
x = gather_from_model_parallel_region(x).contiguous()
x = x + self.bias
return x
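# --- Illustrative sketch (not part of fairseq/megatron): on a single GPU the
# LM head above reduces to a tied linear projection plus bias, with
# copy_to/gather_from becoming no-ops. Made-up sizes:
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, vocab = 8, 100
embed = nn.Embedding(vocab, embed_dim)          # weight shared with the head
features = torch.randn(3, embed_dim)
logits = F.linear(features, embed.weight) + torch.zeros(vocab)  # (3, vocab)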
class ModelParallelRobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
):
super().__init__()
self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class ModelParallelRobertaEncoder(RobertaEncoder):
"""RoBERTa encoder."""
def __init__(self, args, dictionary):
super().__init__(args, dictionary)
assert not self.args.untie_weights_roberta
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)
def build_encoder(self, args, dictionary, embed_tokens):
return ModelParallelTransformerEncoder(args, dictionary, embed_tokens)
def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
# model parallel RoBERTa defaults to "Pre-LN" formulation
roberta_prenorm_architecture(args)
# earlier versions of model parallel RoBERTa removed the final layer norm
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1")
def model_parallel_roberta_v1_architecture(args):
args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True)
base_architecture(args)
@register_model_architecture(
"model_parallel_roberta", "model_parallel_roberta_postnorm"
)
def model_parallel_roberta_postnorm_architecture(args):
# the original BERT/RoBERTa uses the "Post-LN" formulation
roberta_base_architecture(args)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def model_parallel_roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def model_parallel_roberta_large_architecture(args):
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
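# --- Illustrative sketch (not part of fairseq): the architecture functions
# above layer defaults with getattr so explicit user flags always win over
# the preset. Hypothetical args object:
from types import SimpleNamespace

args = SimpleNamespace(encoder_layers=12)                          # user override
args.encoder_layers = getattr(args, "encoder_layers", 24)          # stays 12
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)  # default applies
assert (args.encoder_layers, args.encoder_embed_dim) == (12, 1024)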

View File

@@ -1,121 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn

from fairseq.model_parallel.modules import (
    ModelParallelTransformerDecoderLayer,
    ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)

try:
    from fairseq.model_parallel.megatron.mpu import (
        VocabParallelEmbedding,
        copy_to_model_parallel_region,
        gather_from_model_parallel_region,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


logger = logging.getLogger(__name__)


@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=num_embeddings**-0.5)
            nn.init.constant_(tensor[1], 0)

        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # if provided, load from preloaded dictionaries
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        if args.no_final_layer_norm:
            self.layer_norm = None

    def build_encoder_layer(self, args):
        return ModelParallelTransformerEncoderLayer(args)


class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )

        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)

        if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
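The copy/gather pair in output_layer() implements a vocab-sharded output projection: each rank computes logits for its slice of the vocabulary, and the gather concatenates them along the vocab dimension. A minimal single-process sketch, simulating the sharding with plain tensors (the real mpu primitives operate across ranks; everything below is an assumption-level illustration):

import torch

vocab, dim, ranks = 8, 4, 2
weight = torch.randn(vocab, dim)                     # full output embedding
shards = weight.chunk(ranks, dim=0)                  # each rank holds vocab/ranks rows
features = torch.randn(3, dim)                       # decoder features (copied to every rank)

partial = [features @ w.t() for w in shards]         # per-rank partial logits
logits = torch.cat(partial, dim=-1)                  # what gather_from_model_parallel_region does
assert torch.allclose(logits, features @ weight.t(), atol=1e-5)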

View File

@@ -1,169 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel

try:
    from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
    @staticmethod
    def add_args(parser):
        TransformerLanguageModel.add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        if args.character_embeddings:
            raise NotImplementedError(
                "Character embeddings is not supported for model parallel"
            )
        elif args.adaptive_input:
            raise NotImplementedError(
                "Adaptive input is not supported for model parallel"
            )
        else:
            embed_tokens = cls.build_embedding(
                args, task.source_dictionary, args.decoder_input_dim
            )

        decoder = ModelParallelTransformerDecoder(
            args,
            task.target_dictionary,
            embed_tokens,
            no_encoder_attn=True,
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=embed_dim**-0.5)
            nn.init.constant_(tensor[1], 0)

        embed_tokens = VocabParallelEmbedding(
            len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
        )
        return embed_tokens


def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)
    args.character_filters = getattr(
        args,
        "character_filters",
        "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
    )
    args.character_embedding_dim = getattr(args, "character_embedding_dim", 4)
    args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0)
    args.add_bos_token = getattr(args, "add_bos_token", False)


@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture(
    "model_parallel_transformer_lm", "transformer_lm_megatron_11b"
)
def transformer_lm_megatron_11b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)
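A back-of-the-envelope check, ignoring the embedding table, biases, and layer norms, that the transformer_lm_megatron_11b defaults really land near 11B parameters:

d, ffn, layers = 3072, 3072 * 6, 72
attn = 4 * d * d              # q/k/v/out projections per layer
ff = 2 * d * ffn              # fc1 + fc2 per layer
total = layers * (attn + ff)
print(f"{total / 1e9:.1f}B")  # ~10.9B before embeddings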

View File

@@ -1,17 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
    ModelParallelTransformerEncoderLayer,
    ModelParallelTransformerDecoderLayer,
)

__all__ = [
    "ModelParallelMultiheadAttention",
    "ModelParallelTransformerEncoderLayer",
    "ModelParallelTransformerDecoderLayer",
]

View File

@@ -1,349 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout

try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
        get_cuda_rng_tracker,
        get_model_parallel_world_size,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@with_incremental_state
class ModelParallelMultiheadAttention(nn.Module):
    """Model parallel Multi-headed attention.

    This performs the Multi-headed attention over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (
            self.num_heads_partition * self.model_parallel_size == num_heads
        ), "Number of heads must be divisible by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert (
            not self.self_attention or self.qkv_same_dim
        ), "Self-attention requires query, key and value to be of the same size"

        self.k_proj = ColumnParallelLinear(
            self.kdim, embed_dim, bias=bias, gather_output=False
        )
        self.v_proj = ColumnParallelLinear(
            self.vdim, embed_dim, bias=bias, gather_output=False
        )
        self.q_proj = ColumnParallelLinear(
            embed_dim, embed_dim, bias=bias, gather_output=False
        )
        self.out_proj = RowParallelLinear(
            embed_dim, embed_dim, bias=bias, input_is_parallel=True
        )

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        **unused_kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        is_tpu = query.device.type == "xla"

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = (
                ModelParallelMultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=bsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )
            )

            saved_state["prev_key"] = k.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_value"] = v.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_weights.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(
                bsz, self.num_heads_partition, tgt_len, src_len
            )
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(
                bsz * self.num_heads_partition, tgt_len, src_len
            )

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        with get_cuda_rng_tracker().fork():
            attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            self.head_dim,
        ]
        embed_dim_partition = embed_dim // self.model_parallel_size
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
        attn = self.out_proj(attn)
        # return attn_weights None to keep the return type same as single gpu multihead attention
        # This will be deprecated.
        attn_weights: Optional[Tensor] = None

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
            if prev_key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
            if key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                if input_buffer[k] is not None:
                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
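A minimal single-GPU sketch of the attention math above, with masks, dropout, and incremental state omitted, to make the partitioned shapes concrete: with model_parallel_size=2 each rank runs the same code with num_heads_partition = num_heads // 2 and embed_dim_partition = embed_dim // 2 (values below are illustrative):

import torch

tgt_len, bsz, heads_part, head_dim = 5, 2, 4, 16
q = torch.randn(bsz * heads_part, tgt_len, head_dim) * head_dim**-0.5
k = torch.randn(bsz * heads_part, tgt_len, head_dim)
v = torch.randn(bsz * heads_part, tgt_len, head_dim)

attn_weights = torch.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
attn = torch.bmm(attn_weights, v)            # (bsz*heads_part, tgt_len, head_dim)
out_partition = attn.transpose(0, 1).reshape(tgt_len, bsz, heads_part * head_dim)
# RowParallelLinear (out_proj) then projects and all-reduces out_partition across ranks.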

View File

@@ -1,78 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer

try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )


class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=not getattr(args, "cross_self_attention", False),
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
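The fc1/fc2 pairing above splits the FFN's hidden dimension across ranks: ColumnParallelLinear(gather_output=False) leaves fc1's activations sharded, and RowParallelLinear(input_is_parallel=True) consumes the shard and all-reduces only the final output. A single-process sketch simulating the sharding with plain tensors (not the real mpu classes):

import torch

d, ffn, ranks = 8, 32, 2
w1 = torch.randn(ffn, d)    # fc1 weight, split by output rows across ranks
w2 = torch.randn(d, ffn)    # fc2 weight, split by input columns across ranks
x = torch.randn(3, d)

outs = []
for r in range(ranks):
    h = torch.relu(x @ w1.chunk(ranks, 0)[r].t())   # local fc1 shard + activation
    outs.append(h @ w2.chunk(ranks, 1)[r].t())      # local fc2 partial output
y = sum(outs)                                        # what the all-reduce computes
assert torch.allclose(y, torch.relu(x @ w1.t()) @ w2.t(), atol=1e-5)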

View File

@@ -1,55 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os
from abc import ABC, abstractmethod

from fairseq import registry
from omegaconf import DictConfig


class BaseScorer(ABC):
    def __init__(self, cfg):
        self.cfg = cfg
        self.ref = []
        self.pred = []

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    @abstractmethod
    def score(self) -> float:
        pass

    @abstractmethod
    def result_string(self) -> str:
        pass


_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
    "--scoring", default="bleu"
)


def build_scorer(choice, tgt_dict):
    _choice = choice._name if isinstance(choice, DictConfig) else choice

    if _choice == "bleu":
        from fairseq.scoring import bleu

        return bleu.Scorer(
            bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
        )
    return _build_scorer(choice)


# automatically import any Python files in the current directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.scoring." + module)
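A hedged usage sketch of build_scorer() with the default BLEU backend (assumes a fairseq checkout with the compiled libbleu extension; the default Dictionary gives pad=1, eos=2, unk=3):

import torch
from fairseq.data import Dictionary
from fairseq.scoring import build_scorer

tgt_dict = Dictionary()                     # pad=1, eos=2, unk=3 by default
scorer = build_scorer("bleu", tgt_dict)     # takes the fast native bleu.Scorer path
scorer.add(torch.IntTensor([4, 5, 6, 2]), torch.IntTensor([4, 5, 6, 2]))
print(scorer.result_string())               # corpus BLEU for the accumulated pairs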

View File

@@ -1,44 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import numpy as np

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class BertScoreScorerConfig(FairseqDataclass):
    bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"})


@register_scorer("bert_score", dataclass=BertScoreScorerConfig)
class BertScoreScorer(BaseScorer):
    def __init__(self, cfg):
        super(BertScoreScorer, self).__init__(cfg)
        try:
            import bert_score as _bert_score
        except ImportError:
            raise ImportError("Please install BERTScore: pip install bert-score")

        self.cfg = cfg
        self._bert_score = _bert_score
        self.scores = None

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        _, _, self.scores = self._bert_score.score(
            self.pred, self.ref, lang=self.cfg.bert_score_lang
        )
        self.scores = self.scores.numpy()
        return np.mean(self.scores)

    def result_string(self, order=4):
        return f"BERTScore: {self.score():.4f}"

View File

@@ -1,168 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import math
import sys
from dataclasses import dataclass, field

import torch

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


class BleuStat(ctypes.Structure):
    _fields_ = [
        ("reflen", ctypes.c_size_t),
        ("predlen", ctypes.c_size_t),
        ("match1", ctypes.c_size_t),
        ("count1", ctypes.c_size_t),
        ("match2", ctypes.c_size_t),
        ("count2", ctypes.c_size_t),
        ("match3", ctypes.c_size_t),
        ("count3", ctypes.c_size_t),
        ("match4", ctypes.c_size_t),
        ("count4", ctypes.c_size_t),
    ]


@dataclass
class SacrebleuConfig(FairseqDataclass):
    sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="13a", metadata={"help": "tokenizer"}
    )
    sacrebleu_lowercase: bool = field(
        default=False, metadata={"help": "apply lowercasing"}
    )
    sacrebleu_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )


@register_scorer("sacrebleu", dataclass=SacrebleuConfig)
class SacrebleuScorer(BaseScorer):
    def __init__(self, cfg):
        super(SacrebleuScorer, self).__init__(cfg)
        import sacrebleu

        self.sacrebleu = sacrebleu
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=cfg.sacrebleu_tokenizer,
            lowercase=cfg.sacrebleu_lowercase,
            character_tokenization=cfg.sacrebleu_char_level,
        )

    def add_string(self, ref, pred):
        self.ref.append(self.tokenizer.tokenize(ref))
        self.pred.append(self.tokenizer.tokenize(pred))

    def _score(self, order=4):
        if order != 4:
            raise NotImplementedError
        # tokenization and lowercasing are performed by self.tokenizer instead.
        return self.sacrebleu.corpus_bleu(self.pred, [self.ref], tokenize="none")

    def score(self, order=4):
        return self._score(order).score

    def result_string(self, order=4):
        return self._score(order).format()


@dataclass
class BleuConfig(FairseqDataclass):
    pad: int = field(default=1, metadata={"help": "padding index"})
    eos: int = field(default=2, metadata={"help": "eos index"})
    unk: int = field(default=3, metadata={"help": "unk index"})


@register_scorer("bleu", dataclass=BleuConfig)
class Scorer(object):
    def __init__(self, cfg):
        self.stat = BleuStat()
        self.pad = cfg.pad
        self.eos = cfg.eos
        self.unk = cfg.unk

        try:
            from fairseq import libbleu
        except ImportError as e:
            sys.stderr.write(
                "ERROR: missing libbleu.so. run `pip install --editable .`\n"
            )
            raise e

        self.C = ctypes.cdll.LoadLibrary(libbleu.__file__)

        self.reset()

    def reset(self, one_init=False):
        if one_init:
            self.C.bleu_one_init(ctypes.byref(self.stat))
        else:
            self.C.bleu_zero_init(ctypes.byref(self.stat))

    def add(self, ref, pred):
        if not isinstance(ref, torch.IntTensor):
            raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref)))
        if not isinstance(pred, torch.IntTensor):
            raise TypeError("pred must be a torch.IntTensor (got {})".format(type(pred)))

        # don't match unknown words
        rref = ref.clone()
        assert not rref.lt(0).any()
        rref[rref.eq(self.unk)] = -999

        rref = rref.contiguous().view(-1)
        pred = pred.contiguous().view(-1)

        self.C.bleu_add(
            ctypes.byref(self.stat),
            ctypes.c_size_t(rref.size(0)),
            ctypes.c_void_p(rref.data_ptr()),
            ctypes.c_size_t(pred.size(0)),
            ctypes.c_void_p(pred.data_ptr()),
            ctypes.c_int(self.pad),
            ctypes.c_int(self.eos),
        )

    def score(self, order=4):
        psum = sum(
            math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order]
        )
        return self.brevity() * math.exp(psum / order) * 100

    def precision(self):
        def ratio(a, b):
            return a / b if b > 0 else 0

        return [
            ratio(self.stat.match1, self.stat.count1),
            ratio(self.stat.match2, self.stat.count2),
            ratio(self.stat.match3, self.stat.count3),
            ratio(self.stat.match4, self.stat.count4),
        ]

    def brevity(self):
        r = self.stat.reflen / self.stat.predlen
        return min(1, math.exp(1 - r))

    def result_string(self, order=4):
        assert order <= 4, "BLEU scores for order > 4 aren't supported"
        fmt = "BLEU{} = {:2.2f}, {:2.1f}"
        for _ in range(1, order):
            fmt += "/{:2.1f}"
        fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})"
        bleup = [p * 100 for p in self.precision()[:order]]
        return fmt.format(
            order,
            self.score(order=order),
            *bleup,
            self.brevity(),
            self.stat.predlen / self.stat.reflen,
            self.stat.predlen,
            self.stat.reflen,
        )
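A worked check of the formula in score() above — brevity penalty times the geometric mean of the n-gram precisions, scaled by 100 — using hypothetical precision ratios in place of libbleu's match/count statistics:

import math

precisions = [0.8, 0.6, 0.4, 0.3]            # hypothetical match_n / count_n ratios
reflen, predlen = 10, 9
bp = min(1, math.exp(1 - reflen / predlen))  # brevity(), with r = reflen / predlen
bleu = bp * math.exp(sum(math.log(p) for p in precisions) / 4) * 100
print(f"{bleu:.2f}")                         # ~43.8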

View File

@@ -1,36 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class ChrFScorerConfig(FairseqDataclass):
    pass


@register_scorer("chrf", dataclass=ChrFScorerConfig)
class ChrFScorer(BaseScorer):
    def __init__(self, args):
        super(ChrFScorer, self).__init__(args)
        import sacrebleu

        self.sacrebleu = sacrebleu

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        if order != 4:
            raise NotImplementedError
        # return the numeric score directly; result_string() yields a formatted
        # string, which has no .score attribute
        return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).score

    def result_string(self, order=4):
        if order != 4:
            raise NotImplementedError
        return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format()
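A usage sketch of the underlying sacrebleu call (requires sacrebleu installed):

import sacrebleu

print(sacrebleu.corpus_chrf(["the cat sat on a mat"], [["the cat sat on the mat"]]))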

View File

@@ -1,42 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer


@dataclass
class MeteorScorerConfig(FairseqDataclass):
    pass


@register_scorer("meteor", dataclass=MeteorScorerConfig)
class MeteorScorer(BaseScorer):
    def __init__(self, args):
        super(MeteorScorer, self).__init__(args)
        try:
            import nltk
        except ImportError:
            raise ImportError("Please install nltk to use METEOR scorer")

        self.nltk = nltk
        self.scores = []

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        self.scores = [
            self.nltk.translate.meteor_score.single_meteor_score(r, p)
            for r, p in zip(self.ref, self.pred)
        ]
        return np.mean(self.scores)

    def result_string(self, order=4):
        return f"METEOR: {self.score():.4f}"

View File

@@ -1,80 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unicodedata

import sacrebleu as sb

from fairseq.dataclass import ChoiceEnum

SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2


class EvaluationTokenizer(object):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers
    in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing, punctuation removal and character tokenization, which are
    applied after sacreBLEU tokenization.

    Args:
        tokenizer_type (str): the type of sacreBLEU tokenizer to apply.
        lowercase (bool): lowercase the text.
        punctuation_removal (bool): remove punctuation (based on unicode
            category) from text.
        character_tokenization (bool): tokenize the text to characters.
    """

    SPACE = chr(32)
    SPACE_ESCAPE = chr(9601)
    _ALL_TOKENIZER_TYPES = (
        sb.BLEU.TOKENIZERS
        if SACREBLEU_V2_ABOVE
        else ["none", "13a", "intl", "zh", "ja-mecab"]
    )
    ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES)

    def __init__(
        self,
        tokenizer_type: str = "13a",
        lowercase: bool = False,
        punctuation_removal: bool = False,
        character_tokenization: bool = False,
    ):
        assert (
            tokenizer_type in self._ALL_TOKENIZER_TYPES
        ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}"
        self.lowercase = lowercase
        self.punctuation_removal = punctuation_removal
        self.character_tokenization = character_tokenization
        if SACREBLEU_V2_ABOVE:
            self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer
        else:
            self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]()

    @classmethod
    def remove_punctuation(cls, sent: str):
        """Remove punctuation based on Unicode category."""
        return cls.SPACE.join(
            t
            for t in sent.split(cls.SPACE)
            if not all(unicodedata.category(c)[0] == "P" for c in t)
        )

    def tokenize(self, sent: str):
        tokenized = self.tokenizer(sent)

        if self.punctuation_removal:
            tokenized = self.remove_punctuation(tokenized)

        if self.character_tokenization:
            tokenized = self.SPACE.join(
                list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE))
            )

        if self.lowercase:
            tokenized = tokenized.lower()

        return tokenized
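A usage sketch (requires sacrebleu installed); the "13a" tokenizer splits off punctuation before the post-processing steps above run:

tok = EvaluationTokenizer(tokenizer_type="13a", lowercase=True)
print(tok.tokenize("Hello, World!"))  # hello , world !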

View File

@@ -1,58 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


@dataclass
class WerScorerConfig(FairseqDataclass):
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
    )
    wer_remove_punct: bool = field(
        default=False, metadata={"help": "remove punctuation"}
    )
    wer_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )
    wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})


@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.reset()
        try:
            import editdistance as ed
        except ImportError:
            raise ImportError("Please install editdistance to use WER scorer")
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        return f"WER: {self.score():.2f}"

    def score(self):
        return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0
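A worked example of the arithmetic in score(): accumulated token-level edit distance over total reference length, scaled to a percentage:

ref = "the cat sat on the mat".split()   # 6 reference tokens
pred = "the cat sat on mat".split()      # one deletion -> edit distance 1
print(100.0 * 1 / len(ref))              # ~16.67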

View File

@@ -20,8 +20,8 @@ import os
import logging
DEBUG_PREFIX = "<RVC module>"
INPUT_FILE_PATH = "rvc_input.wav" #"./data/tmp/rvc_input.wav"
OUTPUT_FILE_PATH = "rvc_output.wav" #"./data/tmp/rvc_output.wav"
INPUT_FILE_PATH = "data/tmp/rvc_input.wav" #"./data/tmp/rvc_input.wav"
OUTPUT_FILE_PATH = "data/tmp/rvc_output.wav" #"./data/tmp/rvc_output.wav"
def rvc_process_audio():
"""