12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898 |
- from typing import Iterable, Optional
- import types
- import time
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torch import Tensor
- from torch import nn
- from torch.cuda.amp import autocast
- from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
- from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
- from funasr.train_utils.device_funcs import force_gatherable
- from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
- from funasr.utils.datadir_writer import DatadirWriter
- from funasr.models.ctc.ctc import CTC
- from funasr.register import tables
- from funasr.models.paraformer.search import Hypothesis
class SinusoidalPositionEncoder(torch.nn.Module):
    """Parameter-free sinusoidal absolute position encoding.

    ``forward`` adds to its input a table of
    ``[sin(p * w_0), ..., sin(p * w_{k-1}), cos(p * w_0), ..., cos(p * w_{k-1})]``
    features, where ``p`` runs over positions ``1..T`` and ``k = depth / 2``.
    The module owns no parameters or buffers.
    """

    def __init__(self, d_model=80, dropout_rate=0.1):
        """Create the encoder.

        Args:
            d_model (int): Unused; accepted for constructor compatibility.
            dropout_rate (float): Unused; no dropout is applied.
        """
        # This method used to be misspelled ``__int__`` and therefore never ran,
        # so neither argument ever had an effect. It is now a real constructor
        # that accepts (and ignores) the same arguments.
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        """Build the sinusoidal table for ``positions``.

        Args:
            positions (torch.Tensor): Positions of shape ``(batch, time)``.
            depth (int): Feature size of the encoding. Assumed even (an odd
                value yields ``depth + 1`` features) and greater than 2
                (``depth == 2`` divides by zero in the timescale increment).
            dtype (torch.dtype): Dtype used for the computation and the result.

        Returns:
            torch.Tensor: Encoding of shape ``(batch, time, depth)``.
        """
        positions = positions.type(dtype)
        device = positions.device
        half = depth / 2
        log_timescale_increment = torch.log(
            torch.tensor([10000], dtype=dtype, device=device)
        ) / (half - 1)
        inv_timescales = torch.exp(
            torch.arange(half, device=device).type(dtype) * (-log_timescale_increment)
        )
        # Broadcast (batch, time, 1) against (depth / 2,) so each batch row keeps
        # its own positions. The previous reshape by ``batch_size`` flattened all
        # rows together and was only correct for a batch of one.
        scaled_time = positions.unsqueeze(-1) * inv_timescales
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
        return encoding.type(dtype)

    def forward(self, x):
        """Add the position encoding to ``x`` of shape ``(batch, time, dim)``."""
        _, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward layer: ``w_2(dropout(activation(w_1(x))))``.

    Args:
        idim (int): Input (and output) dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate applied after the activation.
        activation (torch.nn.Module, optional): Activation module. Defaults to
            a fresh ``torch.nn.ReLU()`` for each instance.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=None):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # A module used as a default argument is created once at definition
        # time and would be registered as a submodule of every layer. Build a
        # new one per instance instead.
        self.activation = activation if activation is not None else torch.nn.ReLU()

    def forward(self, x):
        """Apply the feed-forward network over the last dimension of ``x``."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-head self-attention with an FSMN memory branch (SAN-M).

    The output is ``attention(x) + fsmn(v)``. Here ``fsmn`` is a depthwise
    1-D convolution over the value projection ``v``, with a residual
    connection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature size.
        n_feat (int): Attention / output feature size (must be divisible by
            ``n_head``).
        dropout_rate (float): Dropout rate on attention weights and the FSMN
            output.
        kernel_size (int): FSMN convolution kernel size.
        sanm_shfit (int): Extra left padding for the FSMN convolution, which
            moves its window toward past frames.
        lora_list, lora_rank, lora_alpha, lora_dropout: Accepted but unused.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Fused projection producing q, k and v in one matmul.
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise (groups=n_feat) convolution; padding is applied separately
        # by ``pad_fn`` so the output keeps the input length.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block to ``inputs`` (#batch, time, n_feat).

        ``mask`` (if given) is reshaped to (#batch, time, 1), optionally
        multiplied by ``mask_shfit_chunk``, and zeroes padded frames both
        before and after the convolution.
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Project ``x`` into per-head query, key and value tensors.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Query (#batch, n_head, time, d_k).
            torch.Tensor: Key (#batch, n_head, time, d_k).
            torch.Tensor: Value (#batch, n_head, time, d_k).
            torch.Tensor: Value before splitting into heads (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
            mask_att_chunk_encoder (torch.Tensor, optional): Multiplied into
                ``mask`` before use.

        Returns:
            torch.Tensor: Attention output (#batch, time1, n_feat).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            # Fully-masked rows would softmax to NaN; zero them explicitly.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linear_out(x)  # (batch, time1, n_feat)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute SAN-M self-attention over ``x``.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
                (#batch, time, time).
            mask_shfit_chunk (torch.Tensor, optional): Extra mask for the FSMN branch.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra attention mask.

        Returns:
            torch.Tensor: Output tensor (#batch, time, n_feat).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Streaming self-attention over one chunk with a key/value cache.

        Args:
            x (torch.Tensor): Current chunk (#batch, time, in_feat).
            cache (dict, optional): ``{"k": ..., "v": ...}`` per-head keys and
                values from earlier chunks. It is updated in place when given.
            chunk_size (sequence, optional): Indexed as ``chunk_size[1]`` (frames
                kept per look-back step) and ``chunk_size[2]`` (trailing frames of
                the chunk that are not cached).
            look_back (int): Number of past chunks to keep; ``-1`` keeps all,
                ``0`` disables caching.

        Returns:
            tuple: ``(output (#batch, time, n_feat), cache)``.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        # Parenthesized explicitly: previously ``look_back == -1`` entered this
        # branch even when ``chunk_size`` was None and then crashed indexing it.
        if chunk_size is not None and (look_back > 0 or look_back == -1):
            # Keep all but the trailing chunk_size[2] frames for future chunks.
            # An explicit end index (instead of ``[: -chunk_size[2]]``) keeps
            # every frame when that count is 0, and clamps to empty when it
            # exceeds the chunk length.
            stride_end = max(k_h.size(2) - chunk_size[2], 0)
            if cache is not None:
                k_h_stride = k_h[:, :, :stride_end, :]
                v_h_stride = v_h[:, :, :stride_end, :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache = {
                    "k": k_h[:, :, :stride_end, :],
                    "v": v_h[:, :, :stride_end, :],
                }

        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` variant that always normalizes in float32.

    The input and any affine parameters are upcast to float32 for the
    computation. The result is cast back to the input's dtype, so
    reduced-precision inputs are normalized at full precision.
    """

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return normalized.type_as(input)
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a padding mask that is 1 where ``index < length``.

    Args:
        lengths (torch.Tensor): Integer tensor of sequence lengths.
        maxlen (int, optional): Mask width; defaults to ``lengths.max()``.
        dtype (torch.dtype): Dtype of the returned mask.
        device (optional): When given, the mask is moved to this device.

    Returns:
        torch.Tensor: Mask of shape ``(*lengths.shape, maxlen)`` (detached).
    """
    width = lengths.max() if maxlen is None else maxlen
    steps = torch.arange(0, width, 1, device=lengths.device)
    mask = (steps < lengths.unsqueeze(-1)).detach().type(dtype)
    if device is not None:
        mask = mask.to(device)
    return mask
class EncoderLayerSANM(nn.Module):
    """SAN-M encoder layer: self-attention sub-layer followed by feed-forward.

    Args:
        in_size (int): Feature size of the layer input.
        size (int): Feature size of the layer output. The attention residual
            connection is only used when ``in_size == size``.
        self_attn (nn.Module): Attention module, called as
            ``self_attn(x, mask, mask_shfit_chunk=..., mask_att_chunk_encoder=...)``
            and, when streaming, ``self_attn.forward_chunk(x, cache, chunk_size, look_back)``.
        feed_forward (nn.Module): Position-wise feed-forward module.
        dropout_rate (float): Dropout applied to each sub-layer output.
        normalize_before (bool): Pre-norm if True, post-norm otherwise.
        concat_after (bool): If True, concatenate the attention input and output
            and project them with ``concat_linear`` instead of applying dropout.
        stochastic_depth_rate (float): Probability of skipping the layer during
            training.
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayerSANM object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # Projects [x (in_size) ; attention output (size)] back to ``size``.
            # It was previously sized ``size + size``, which only matched the
            # concatenated input when in_size == size.
            self.concat_linear = nn.Linear(in_size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_size).
            mask (torch.Tensor): Mask forwarded to ``self_attn``; presumably
                (#batch, 1, time) — confirm against callers.
            cache (torch.Tensor, optional): Only used when the layer is skipped by
                stochastic depth, where it is prepended to ``x`` along time.
            mask_shfit_chunk, mask_att_chunk_encoder: Forwarded to ``self_attn``.

        Returns:
            tuple: ``(x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder)``.
            Both the normal and the skipped path return this same 5-tuple.
        """
        skip_layer = False
        # With stochastic depth, the residual connection ``x + f(x)`` becomes
        # ``x <- x + 1 / (1 - p) * f(x)`` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # Previously returned only ``(x, mask)``. Keep the arity consistent
            # with the normal path so callers can unpack uniformly.
            return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        attn_out = self.self_attn(
            x,
            mask,
            mask_shfit_chunk=mask_shfit_chunk,
            mask_att_chunk_encoder=mask_att_chunk_encoder,
        )
        if self.concat_after:
            update = stoch_layer_coeff * self.concat_linear(torch.cat((x, attn_out), dim=-1))
        else:
            update = stoch_layer_coeff * self.dropout(attn_out)
        # A residual connection is only possible when the feature size is unchanged.
        x = residual + update if self.in_size == self.size else update
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Streaming variant of ``forward`` for one chunk.

        No dropout, stochastic depth or ``concat_after`` handling is applied here.

        Args:
            x (torch.Tensor): Input chunk (#batch, time, in_size).
            cache: Attention cache passed to ``self_attn.forward_chunk``.
            chunk_size: Passed to ``self_attn.forward_chunk``.
            look_back (int): Passed to ``self_attn.forward_chunk``.

        Returns:
            tuple: ``(output (#batch, time, size), updated cache)``.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

    Stack of SAN-M encoder layers:
    ``encoders0`` (one layer mapping ``input_size -> output_size``), then
    ``encoders`` (``num_blocks - 1`` layers), then ``after_norm``, then
    ``tp_encoders`` (``tp_blocks`` layers), then ``tp_norm``.
    Many constructor arguments (e.g. ``input_layer``, ``pos_enc_class``,
    ``normalize_before``, ``concat_after``) are accepted but not used to build
    the layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        # NOTE(review): ``pos_enc_class`` is ignored; a SinusoidalPositionEncoder
        # is always used.
        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first layer projects from the raw input size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for _ in range(num_blocks - 1)
            ]
        )
        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for _ in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)
        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the feature size of the encoder output."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Encode padded features.

        Args:
            xs_pad (torch.Tensor): Features (#batch, time, input_size).
            ilens (torch.Tensor): Valid lengths (#batch,).

        Returns:
            tuple: ``(encoded (#batch, time, output_size), olens (#batch,) int)``.
        """
        masks = sequence_mask(ilens, device=ilens.device)[:, None, :]

        # Scale out of place: ``*=`` silently modified the caller's tensor (and
        # fails on leaf tensors that require grad).
        xs_pad = xs_pad * self.output_size() ** 0.5
        xs_pad = self.embed(xs_pad)

        for encoder_layer in self.encoders0:
            xs_pad, masks = encoder_layer(xs_pad, masks)[:2]
        for encoder_layer in self.encoders:
            xs_pad, masks = encoder_layer(xs_pad, masks)[:2]
        xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1).int()

        for encoder_layer in self.tp_encoders:
            xs_pad, masks = encoder_layer(xs_pad, masks)[:2]
        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens
@tables.register("model_classes", "SenseVoiceSmall")
class SenseVoiceSmall(nn.Module):
    """CTC model with four learned prompt query frames.

    Before encoding, four embedding frames are prepended to the features:
    a language query, two fixed queries (embedding ids 1 and 2), and a
    text-normalization query. During training, CTC loss is computed on the
    frames after the first four. A cross-entropy "rich" loss is computed on the
    first four frames through the CTC output projection (``ctc.ctc_lo``).
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if ctc_conf is None:
            ctc_conf = {}
        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None

        self.ctc = ctc

        self.length_normalized_loss = length_normalized_loss
        self.encoder_output_size = encoder_output_size

        # Language / text-norm names (inference) and token ids (training labels)
        # mapped to rows of ``self.embed``.
        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)

        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

    @staticmethod
    def from_pretrained(model: str = None, **kwargs):
        """Build a model through ``funasr.AutoModel.build_model``."""
        from funasr import AutoModel

        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)

        return model, kwargs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + CTC / rich losses.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length). The first 4 tokens are the prompt labels;
                ``text[:, 0]`` and ``text[:, 3]`` are read as language and
                text-norm ids.
            text_lengths: (Batch,)

        Returns:
            tuple: ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        stats = dict()

        # The first 4 frames / tokens belong to the prompt queries.
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
        )
        loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4])

        # NOTE(review): loss_rich is computed and logged but not added to the
        # optimized loss; confirm this is intentional.
        loss = loss_ctc

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
        stats["acc_rich"] = acc_rich

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _prepend_queries(self, speech, speech_lengths, language_query, textnorm_query):
        """Prepend [language, id-1, id-2, text-norm] query frames to ``speech``.

        ``language_query`` and ``textnorm_query`` are (#batch, 1, input_size).
        Returns the extended features and ``speech_lengths + 4``. The lengths
        are updated out of place so the caller's tensor is not modified.
        """
        speech = torch.cat((textnorm_query, speech), dim=1)
        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
            speech.size(0), 1, 1
        )
        input_query = torch.cat((language_query, event_emo_query), dim=1)
        speech = torch.cat((input_query, speech), dim=1)
        return speech, speech_lengths + 4

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, ). Not modified in place.
            text: (Batch, Length); ``text[:, 0]`` and ``text[:, 3]`` select the
                language and text-norm queries.
        """
        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Known language ids are kept only when torch.rand(1) > 0.2; otherwise
        # (and for unknown ids) row 0 ("auto") is used.
        # NOTE(review): this random replacement also happens in eval mode.
        lids = torch.LongTensor(
            [
                [
                    self.lid_int_dict[int(lid)]
                    if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
                    else 0
                ]
                for lid in text[:, 0]
            ]
        ).to(speech.device)
        language_query = self.embed(lids)

        styles = torch.LongTensor(
            [[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]
        ).to(speech.device)
        style_query = self.embed(styles)

        speech, speech_lengths = self._prepend_queries(
            speech, speech_lengths, language_query, style_query
        )

        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)

        return encoder_out, encoder_out_lens

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(ctc_loss, cer)``; ``cer`` is only computed in eval mode
        when an error calculator is set, otherwise None."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rich_ce_loss(
        self,
        encoder_out: torch.Tensor,
        ys_pad: torch.Tensor,
    ):
        """Cross-entropy loss and accuracy of the prompt frames, scored through
        the CTC output projection."""
        decoder_out = self.ctc.ctc_lo(encoder_out)
        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
        acc_rich = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad.contiguous(),
            ignore_label=self.ignore_id,
        )

        return loss_rich, acc_rich

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Greedy CTC decoding.

        Args:
            data_in: Raw inputs for ``load_audio_text_image_video``, or a feature
                tensor when ``data_type == "fbank"``.
            data_lengths: Feature lengths for the fbank path; when None, every
                item is taken to span the full time axis.
            key (list, optional): Utterance keys; defaults to
                ``["wav_file_tmp_name"]``.
            tokenizer: Used to decode token ids to text.
            frontend: Feature extractor (required for non-fbank input).
            **kwargs: Reads ``device`` (required), ``data_type``, ``fs``,
                ``language``, ``use_itn``, ``text_norm``, ``output_dir``.

        Returns:
            tuple: ``(results, meta_data)`` with one ``{"key", "text"}`` dict per item.
        """
        if key is None:
            key = ["wav_file_tmp_name"]

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # One full-length entry per batch item. A bare int (as before)
                # crashed at ``.to()`` below and in the encoder's mask.
                speech_lengths = torch.full(
                    (speech.shape[0],), speech.shape[1], dtype=torch.long
                )
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        language = kwargs.get("language", "auto")
        language_query = self.embed(
            torch.LongTensor([[self.lid_dict.get(language, 0)]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)

        use_itn = kwargs.get("use_itn", False)
        textnorm = kwargs.get("text_norm", None)
        if textnorm is None:
            textnorm = "withitn" if use_itn else "woitn"
        textnorm_query = self.embed(
            torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
        ).repeat(speech.size(0), 1, 1)

        speech, speech_lengths = self._prepend_queries(
            speech, speech_lengths, language_query, textnorm_query
        )

        # Encoder
        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Greedy CTC: argmax, collapse repeats, drop blanks.
        ctc_logits = self.ctc.log_softmax(encoder_out)

        results = []
        b, _, _ = encoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer["1best_recog"]

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            # Change integer-ids to tokens
            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                ibest_writer["text"][key[i]] = text

        return results, meta_data

    def export(self, **kwargs):
        """Rebuild the model for export (``max_seq_len`` defaults to 512)."""
        from .export_meta import export_rebuild_model

        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
|