Overcoming Non-monotonicity in Transducer-based Streaming Generation

Authors: Zhengrui Ma, Yang Feng*, Min Zhang

Files:

  • We mainly provide the following files in the fs_plugins directory, as plugins for fairseq (commit 920a54):

    fs_plugins
    ├── agents
    │   ├── attention_transducer_agent.py
    │   ├── monotonic_transducer_agent.py
    │   ├── transducer_agent.py
    │   └── transducer_agent_v2.py
    ├── criterions
    │   ├── __init__.py
    │   ├── transducer_loss.py             
    │   └── transducer_loss_asr.py                    
    ├── datasets
    │   └── transducer_speech_to_text_dataset.py
    ├── models
    │   ├── transducer
    │   │    ├── __init__.py
    │   │    ├── attention_transducer.py
    │   │    ├── monotonic_transducer.py
    │   │    ├── monotonic_transducer_chunk_diagonal_prior.py
    │   │    ├── monotonic_transducer_chunk_diagonal_prior_only.py
    │   │    ├── monotonic_transducer_diagonal_prior.py
    │   │    ├── transducer.py
    │   │    ├── transducer_config.py
    │   │    └── transducer_loss.py
    │   └── __init__.py
    ├── modules 
    │   ├── attention_transducer_decoder.py
    │   ├── audio_convs.py
    │   ├── audio_encoder.py
    │   ├── monotonic_transducer_decoder.py
    │   ├── monotonic_transformer_layer.py
    │   ├── multihead_attention_patched.py
    │   ├── multihead_attention_relative.py
    │   ├── rand_pos.py
    │   ├── transducer_decoder.py
    │   ├── transducer_monotonic_multihead_attention.py
    │   └── unidirectional_encoder.py
    ├── optim 
    │   ├── __init__.py 
    │   └── radam.py
    ├── scripts 
    │   ├── average_checkpoints.py
    │   ├── prep_mustc_data.py
    │   └── substitute_target.py
    ├── tasks 
    │   ├── __init__.py 
    │   └── transducer_speech_to_text.py
    ├── __init__.py
    └── utils.py
    

Data Preparation

Please refer to Fairseq's speech-to-text modeling tutorial.
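
A typical preprocessing run for the ST task looks like the sketch below, which follows fairseq's documented MuST-C example (the repository also ships its own copy under fs_plugins/scripts/prep_mustc_data.py, which we assume takes the same flags):

MUSTC_ROOT=/path_to_your_dataset/mustc/

# Extract fbank features and build the manifests, config_st.yaml and a unigram vocabulary
python examples/speech_to_text/prep_mustc_data.py \
    --data-root ${MUSTC_ROOT} --task st \
    --vocab-type unigram --vocab-size 8000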

Training Transformer-Transducer

ASR Pretraining

We use a batch size of approximately 160k tokens (number of GPUs * max_tokens * update_freq == 160k); e.g., a single GPU with max_tokens=8000 and update_freq=20 gives 1 * 8000 * 20 = 160k.

main=64
downsample=4
lr=5e-4
warm=4000
dropout=0.1
tokens=8000
language=es

exp=en${language}.asr.cs_${main}.ds_${downsample}.kd.t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
MUSTC_ROOT=/path_to_your_dataset/mustc/
checkpoint_dir=./checkpoints/$exp

nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
    --amp \
    --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
    --user-dir fs_plugins \
    --task transducer_speech_to_text --arch t_t \
    --max-source-positions 6000 --max-target-positions 1024 \
    --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
    --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
    --activation-dropout 0.1 --attention-dropout 0.1 \
    --criterion transducer_loss_asr \
    --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr ${lr} --lr-scheduler inverse_sqrt \
    --warmup-init-lr '1e-07' --warmup-updates ${warm} \
    --stop-min-lr '1e-09' --max-update 150000 \
    --max-tokens ${tokens} --update-freq 20 --grouped-shuffling \
    --save-dir ${checkpoint_dir} \
    --ddp-backend=legacy_ddp \
    --no-progress-bar --log-format json --log-interval 100 \
    --save-interval-updates 2000 --keep-interval-updates 10 \
    --save-interval 1000 --keep-last-epochs 10 \
    --fixed-validation-seed 7 \
    --skip-invalid-size-inputs-valid-test \
    --validate-interval 1000 --validate-interval-updates 2000 \
    --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
    --patience 20 --num-workers 8 \
    --tensorboard-logdir logs_board/$exp >> logs/$exp.txt &
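
The later stages load an averaged checkpoint (average.pt below; the inference sections likewise use names such as average_last_5_40000). A minimal averaging sketch, assuming the bundled fs_plugins/scripts/average_checkpoints.py accepts the same flags as fairseq's scripts/average_checkpoints.py:

# Average the last 5 update checkpoints saved in ${checkpoint_dir}
python fs_plugins/scripts/average_checkpoints.py \
    --inputs ${checkpoint_dir} \
    --num-update-checkpoints 5 \
    --output /path_to_asr_pretrained_checkpoint/average.pt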

ST Training

We use a batch size of approximately 160k tokens (number of GPUs * max_tokens * update_freq == 160k).

main=64
downsample=4
lr=5e-4
warm=4000
dropout=0.1
tokens=8000
language=es
pretrained_path=/path_to_asr_pretrained_checkpoint/average.pt


exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
MUSTC_ROOT=/path_to_your_dataset/mustc/
checkpoint_dir=./checkpoints/en-${language}/st/$exp

nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
    --load-pretrained-encoder-from ${pretrained_path} \
    --amp \
    --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
    --user-dir fs_plugins \
    --task transducer_speech_to_text --arch t_t \
    --max-source-positions 6000 --max-target-positions 1024 \
    --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
    --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
    --activation-dropout 0.1 --attention-dropout 0.1 \
    --criterion transducer_loss \
    --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr ${lr} --lr-scheduler inverse_sqrt \
    --warmup-init-lr '1e-07' --warmup-updates ${warm} \
    --stop-min-lr '1e-09' --max-update 150000 \
    --max-tokens ${tokens} --update-freq 10 --grouped-shuffling \
    --save-dir ${checkpoint_dir} \
    --ddp-backend=legacy_ddp \
    --no-progress-bar --log-format json --log-interval 100 \
    --save-interval-updates 2000 --keep-interval-updates 10 \
    --save-interval 1000 --keep-last-epochs 10 \
    --fixed-validation-seed 7 \
    --skip-invalid-size-inputs-valid-test \
    --validate-interval 1000 --validate-interval-updates 2000 \
    --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
    --patience 20 --num-workers 8 \
    --tensorboard-logdir logs_board/$exp >> logs/$exp.txt &

Training MonoAttn-Transducer

Offline-Attn Pretraining

We use a batch size of approximately 160k tokens (number of GPUs * max_tokens * update_freq == 160k).

main=64
downsample=4
lr=5e-4
warm=4000
dropout=0.1
tokens=8000
language=es
pretrained_path=/path_to_asr_pretrained_checkpoint/average.pt  # use the Transformer-Transducer ASR-pretrained checkpoint

exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.attn_t_t.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
MUSTC_ROOT=/path_to_your_dataset/mustc/
checkpoint_dir=./checkpoints/en-${language}/st/$exp

nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
    --load-pretrained-encoder-from ${pretrained_path} \
    --amp \
    --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
    --user-dir fs_plugins \
    --task transducer_speech_to_text --arch attention_t_t \
    --max-source-positions 6000 --max-target-positions 1024 \
    --main-context ${main} --right-context 0 --transducer-downsample ${downsample} \
    --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
    --activation-dropout 0.1 --attention-dropout 0.1 \
    --criterion transducer_loss \
    --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr ${lr} --lr-scheduler inverse_sqrt \
    --warmup-init-lr '1e-07' --warmup-updates ${warm} \
    --stop-min-lr '1e-09' --max-update 50000 \
    --max-tokens ${tokens} --update-freq 5 --grouped-shuffling \
    --save-dir ${checkpoint_dir} \
    --ddp-backend=legacy_ddp \
    --no-progress-bar --log-format json --log-interval 100 \
    --save-interval-updates 2000 --keep-interval-updates 10 \
    --save-interval 1000 --keep-last-epochs 10 \
    --fixed-validation-seed 7 \
    --skip-invalid-size-inputs-valid-test \
    --validate-interval 1000 --validate-interval-updates 2000 \
    --best-checkpoint-metric rnn_t_loss --keep-best-checkpoints 5 \
    --patience 20 --num-workers 8 --max-tokens-valid 4800 \
    --tensorboard-logdir logs_board/$exp > logs/$exp.txt &

Mono-Attn Training

We use a batch size of approximately 160k tokens (number of GPUs * max_tokens * update_freq == 160k).

main=64
downsample=4
lr=5e-4
warm=4000
dropout=0.1
tokens=10000
language=es
                                 
pretrained_path=/path_to_offline_attn_trained_model/average.pt

exp=en${language}.s2t.cs_${main}.ds_${downsample}.kd.mono_t_t_chunk_dia_prior.add.prenorm.amp.adam.lr_${lr}.warm_${warm}.drop_${dropout}.tk_${tokens}.bsz_160k
MUSTC_ROOT=/path_to_your_dataset/mustc/
checkpoint_dir=./checkpoints/en-${language}/st/$exp

nohup fairseq-train ${MUSTC_ROOT}/en-${language} \
    --load-pretrained-encoder-from ${pretrained_path} \
    --load-pretrained-decoder-from ${pretrained_path} \
    --amp \
    --config-yaml config_st.yaml --train-subset train_st_distilled --valid-subset dev_st \
    --user-dir fs_plugins \
    --task transducer_speech_to_text --arch monotonic_t_t_chunk_diagonal_prior \
    --max-source-positions 6000 --max-target-positions 1024 \
    --main-context ${main} --right-context ${main} --transducer-downsample ${downsample} \
    --share-decoder-input-output-embed --rand-pos-encoder 300 --encoder-max-relative-position 32 \
    --activation-dropout 0.1 --attention-dropout 0.1 \
    --criterion transducer_loss \
    --dropout ${dropout} --weight-decay 0.01 --clip-norm 5.0 \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr ${lr} --lr-scheduler inverse_sqrt \
    --warmup-init-lr '1e-07' --warmup-updates ${warm} \
    --stop-min-lr '1e-09' --max-update 20000 \
    --max-tokens ${tokens} --update-freq 8 --grouped-shuffling \
    --save-dir ${checkpoint_dir} \
    --ddp-backend=legacy_ddp \
    --no-progress-bar --log-format json --log-interval 100 \
    --save-interval-updates 2000 --keep-interval-updates 20 \
    --save-interval 1000 --keep-last-epochs 10 \
    --fixed-validation-seed 7 \
    --skip-invalid-size-inputs-valid-test \
    --validate-interval 1000 --validate-interval-updates 2000 \
    --best-checkpoint-metric montonic_rnn_t_loss --keep-best-checkpoints 5 \
    --patience 20 --num-workers 8 \
    --tensorboard-logdir logs_board/$exp > logs/$exp.txt &

Inference

Testing Transformer-Transducer

Use the agent transducer_agent_v2:

LANGUAGE=es
exp=enes.s2t.cs_64.ds_4.kd.t_t.add.prenorm.amp.adam.lr_5e-4.warm_4000.drop_0.1.tk_10000.bsz_160k
ckpt=average_last_5_40000
file=./checkpoints/en-${LANGUAGE}/st/${exp}/${ckpt}.pt
output_dir=./results/en-${LANGUAGE}/st
main_context=64
downsample=4
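
# Note: --source-segment-size below is in milliseconds; ${main_context}0 string-appends
# a zero, so main_context=64 yields 640 ms of audio per streaming chunk
# (64 encoder frames, assuming the usual 10 ms feature shift).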

simuleval \
    --data-bin /dataset/mustc/en-${LANGUAGE} \
    --source /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.wav_list --target /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.${LANGUAGE} \
    --model-path $file \
    --config-yaml config_st.yaml \
    --agent ./fs_plugins/agents/transducer_agent_v2.py \
    --transducer-downsample ${downsample} --main-context ${main_context} --right-context ${main_context} \
    --source-segment-size ${main_context}0 \
    --output $output_dir/${exp}_${ckpt} \
    --quality-metrics BLEU  --latency-metrics AL \
    --device gpu
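
The BLEU and AL scores are written to the --output directory along with per-instance logs. If your SimulEval version supports it (v1.1 does, to our knowledge), a finished run can be re-scored from those logs without re-running the model:

simuleval --score-only --output $output_dir/${exp}_${ckpt}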

Testing MonoAttn-Transducer

Use the agent monotonic_transducer_agent:

LANGUAGE=es
exp=enes.s2t.cs_64.ds_4.kd.mono_t_t.add.prenorm.amp.adam.lr_5e-4.warm_4000.drop_0.1.tk_10000.bsz_160k
ckpt=average_last_5_40000
file=./checkpoints/en-${LANGUAGE}/st/${exp}/${ckpt}.pt
output_dir=./results/en-${LANGUAGE}/st
main_context=64
downsample=4

simuleval \
    --data-bin /dataset/mustc/en-${LANGUAGE} \
    --source /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.wav_list --target /dataset/mustc/en-${LANGUAGE}/data_segment/tst-COMMON.${LANGUAGE} \
    --model-path $file \
    --config-yaml config_st.yaml \
    --agent ./fs_plugins/agents/monotonic_transducer_agent.py \
    --transducer-downsample ${downsample} --main-context ${main_context} --right-context ${main_context} \
    --source-segment-size ${main_context}0 \
    --output $output_dir/${exp}_${ckpt} \
    --quality-metrics BLEU  --latency-metrics AL \
    --device gpu

Citing

Please cite us if you find our paper or code useful.

@inproceedings{ma2025overcoming,
    title={Overcoming Non-monotonicity in Transducer-based Streaming Generation},
    author={Zhengrui Ma and Yang Feng and Min Zhang},
    booktitle={Proceedings of the 42nd International Conference on Machine Learning},
    year={2025},
    url={https://arxiv.org/abs/2411.17170}
}
