An unofficial PyTorch implementation of "Lin et al. ViBERTgrid: A Jointly Trained Multi-Modal 2D Document Representation for Key Information Extraction from Documents. ICDAR, 2021"

ZeningLin/ViBERTgrid-PyTorch

ViBERTgrid PyTorch

An Unofficial PyTorch implementation of Lin et al. ViBERTgrid: A Jointly Trained Multi-Modal 2D Document Representation for Key Information Extraction from Documents. ICDAR, 2021.

To learn more about Visual Information Extraction and Document AI, please refer to Document-AI-Recommendations.

Repo Structure

  • data: data loaders
  • model: ViBERTgrid net architecture
  • pipeline: preprocessing, trainer, evaluation metrics
  • utils: data visualization, dataset splitting
  • deployment: examples for inference and deployment
  • train_*.py: main train scripts
  • eval_*.py: evaluation scripts

Usage

1. Env Setting


pip install -r requirements.txt

2. Data Preprocessing

The following components are required for the processed dataset:

  • the original image [bs, 3, h, w]
  • label_csv: a label file in CSV format. Each row describes one text segment and should contain its text, left-coor, right-coor, top-coor, bot-coor, class type, and pos_neg type.
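
As a rough illustration of the label file described above, the sketch below writes a small CSV with one row per text segment and reads it back. The column names (`text`, `left`, `right`, `top`, `bot`, `data_class`, `pos_neg`) and the sample values are assumptions made for this example; check the preprocessing scripts for the exact header this repo expects.

```python
import csv
import io

# Hypothetical column names; the repo's actual header may differ.
FIELDS = ["text", "left", "right", "top", "bot", "data_class", "pos_neg"]

# Made-up rows: one text segment per row, coordinates in pixels.
rows = [
    {"text": "TOTAL", "left": 50, "right": 140, "top": 600, "bot": 630,
     "data_class": 0, "pos_neg": 0},
    {"text": "9.00", "left": 400, "right": 470, "top": 600, "bot": 630,
     "data_class": 4, "pos_neg": 1},
]


def write_labels(rows, handle):
    """Write the header and one line per text segment."""
    writer = csv.DictWriter(handle, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)


def read_labels(handle):
    """Read rows back, converting every column except text to int."""
    loaded = []
    for row in csv.DictReader(handle):
        for key in FIELDS[1:]:
            row[key] = int(row[key])
        loaded.append(dict(row))
    return loaded


# An in-memory buffer stands in for a label file on disk.
buffer = io.StringIO(newline="")
write_labels(rows, buffer)
buffer.seek(0)
loaded = read_labels(buffer)
```

The same two functions work on a real file opened with `open(path, newline="")`.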

2.1 ICDAR SROIE

Note that the labeling of the dataset strongly affects the final result. The original SROIE dataset (https://rrc.cvc.uab.es/?ch=13&com=tasks) only contains the text labels of the key information. Coordinates, however, are necessary for constructing the grid, so we re-labeled the dataset to obtain them. Unfortunately, our re-labeled data cannot be made public. However, the model trained on our re-labeled data is available here (97+ entity-level F1).

An alternative is to match the coordinates automatically using regular expressions and cosine similarity, following (https://github.com/antoinedelplace/Chargrid). The matching result is unsatisfactory, however, and only reaches an entity-level F1 of around 60.
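
To give a feel for this kind of matching, here is a minimal sketch that normalizes text with a regular expression and picks the OCR line whose character-bigram vector has the highest cosine similarity to a key value. It is not the procedure used by the Chargrid repository or by this repo; the function names, the bigram features, the `(left, top, right, bottom)` box format, and the sample lines are all assumptions for illustration.

```python
import math
import re
from collections import Counter


def normalize(text):
    """Lowercase and drop everything except letters and digits."""
    return re.sub(r"[^a-z0-9]", "", text.lower())


def bigram_vector(text):
    """Bag of character bigrams of the normalized text."""
    s = normalize(text)
    return Counter(s[i:i + 2] for i in range(len(s) - 1))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[k] * b[k] for k in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def match_key(key_text, ocr_lines):
    """Return (best OCR line, similarity score) for one key value."""
    key_vec = bigram_vector(key_text)
    scored = [
        (cosine(key_vec, bigram_vector(text)), (text, box))
        for text, box in ocr_lines
    ]
    best_score, best_line = max(scored, key=lambda item: item[0])
    return best_line, best_score


# Made-up OCR output: (text, (left, top, right, bottom)).
ocr_lines = [
    ("BOOK TA .K (TAMAN DAYA) SDN BHD", (72, 25, 326, 64)),
    ("DATE: 25/12/2018 8:13:39 PM", (50, 342, 279, 372)),
    ("TOTAL: 9.00", (61, 597, 353, 626)),
]
line, score = match_key("25/12/2018", ocr_lines)
```

Here the date key is matched to the line starting with `DATE:`, whose box then supplies the coordinates. Short or repeated values (for example a total that also appears as a subtotal) are where this kind of matching tends to go wrong.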

You can preprocess the data by following the steps below, or simply download the preprocessed version from here.

  1. Download the official SROIE dataset here.
  2. 0325updated.task1train(626p) contains the images and OCR results of the training set. Put the images in the img folder and the txt files in the box folder. The txt files in 0325updated.task2train(626p) are key-type labels; put them in the key folder.
  3. Run sroie_data_preprocessing.py. The train_raw argument is the root directory of the three folders mentioned above. The train_processed argument is the directory where sroie_data_preprocessing.py writes the processed CSV labels; name it label. Put the img, key, and label folders in the same root directory, named train.
  4. Download the raw data of the test set from the links provided at the bottom of the official page (see here and here). Follow the instructions in step 2, but name the folder test. Put it in the same root directory as the train folder, and name that root SROIE.
  5. The processed data should be organized as shown below
    .
    ├── test
    │   ├── image
    │   ├── key
    │   └── label
    └── train
        ├── image
        ├── key
        └── label
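
A quick way to confirm that a dataset root matches the tree above is to check for the six sub-folders. The helper below is a hypothetical sketch, not part of this repo, and it is demonstrated on a temporary directory rather than real data.

```python
import tempfile
from pathlib import Path

SPLITS = ("train", "test")
SUBDIRS = ("image", "key", "label")


def missing_dirs(root):
    """List the expected sub-folders that do not exist under root."""
    root = Path(root)
    return [
        str(Path(split) / sub)
        for split in SPLITS
        for sub in SUBDIRS
        if not (root / split / sub).is_dir()
    ]


# Demo on a temporary directory standing in for the SROIE root.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "SROIE"
    before = missing_dirs(root)  # nothing created yet
    for split in SPLITS:
        for sub in SUBDIRS:
            (root / split / sub).mkdir(parents=True)
    after = missing_dirs(root)  # full layout in place
```

On real data, call `missing_dirs` with the path to your SROIE root; an empty list means the layout is complete.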
    

We recommend re-labeling the dataset yourself, since it contains only around 1k images and will not take much time, or finding a better way to match the coordinates.

2.2 EPHOIE

The dataset can be obtained from (https://github.com/HCIILAB/EPHOIE). An unzip password will be provided after you submit the application form.
EPHOIE provides labels in txt format, so you should first convert them into JSON format on your own. Then run the following command:

python ./pipeline/ephoie_data_preprocessing.py

2.3 FUNSD

Images and labels can be found here.
The FUNSD dataset contains two subtasks: entity labeling and entity linking. The ViBERTgrid model can only perform KIE on the first task, in which the text contents are labeled with three key types (header, question, answer). Run the following command to generate formatted labels.

python ./pipeline/funsd_data_preprocessing.py

3. Training

First, set up the configuration. An example config file, example_config.yaml, is provided. Then run the following command, replacing * with SROIE, EPHOIE, or FUNSD.

torchrun --nnodes 1 --nproc_per_node 2 ./train_*.py -c dir_to_config_file.yaml

4. Inference

Scripts for inference are provided in the deployment folder. Run inference_* to get the VIE result in JSON format.


Adaptation and Exploration in This Implementation

1. Data Level

In the paper, the authors apply classification at the word level: the model predicts the key type of each word, then joins the words that belong to the same class to form the final entity-level result.

In fact, ViBERTgrid can work at any data level, such as line level or char level. Choosing a proper data level may significantly boost the final score. According to our experiments, line level is the most suitable choice for the SROIE dataset, char level for EPHOIE, and segment level for FUNSD.
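
The joining step described for word-level prediction can be sketched as follows. This is a simplified illustration and not the repo's post-processing code: the function name, the `other` background label, and the sample predictions are assumptions. The same grouping applies unchanged when the units are lines, chars, or segments instead of words.

```python
from collections import defaultdict


def join_by_class(units, classes, background="other"):
    """Join text units that share a predicted class, in reading order."""
    grouped = defaultdict(list)
    for unit, cls in zip(units, classes):
        if cls != background:  # background units are not key information
            grouped[cls].append(unit)
    return {cls: " ".join(parts) for cls, parts in grouped.items()}


# Made-up word-level predictions for a receipt.
words = ["BOOK", "TA", "SDN", "BHD", "TOTAL", "9.00"]
classes = ["company", "company", "company", "company", "other", "total"]
entities = join_by_class(words, classes)
```

With word-level units, a single misclassified word corrupts the joined entity, which is one plausible reason a coarser level can score better on some datasets.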

2. CNN Backbone

The authors of the paper used an ImageNet-pretrained ResNet18-D to initialize the weights of the CNN backbone. Pretrained weights for ResNet-D, however, cannot be found in PyTorch's model zoo, so we use an ImageNet-pretrained ResNet34 instead.

The CNN backbone can be changed by setting a different value in the config file. Supported backbones are listed below:

  • resnet_18_fpn
  • resnet_34_fpn
  • resnet_18_fpn_pretrained
  • resnet_34_fpn_pretrained
  • resnet_18_D_fpn
  • resnet_34_D_fpn
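
The backbone names appear to follow a pattern: ResNet depth, an optional `D` for the ResNet-D variant, `fpn`, and an optional `pretrained` suffix. The helper below is a hypothetical sketch of that reading, useful for validating a config value before training; it is not part of this repo, and the interpretation of each name component is an assumption.

```python
# The six names listed above, copied verbatim.
SUPPORTED_BACKBONES = (
    "resnet_18_fpn",
    "resnet_34_fpn",
    "resnet_18_fpn_pretrained",
    "resnet_34_fpn_pretrained",
    "resnet_18_D_fpn",
    "resnet_34_D_fpn",
)


def describe_backbone(name):
    """Split a supported backbone name into its (assumed) components."""
    if name not in SUPPORTED_BACKBONES:
        raise ValueError(f"unsupported backbone: {name!r}")
    parts = name.split("_")
    return {
        "depth": int(parts[1]),                   # 18 or 34 layers
        "resnet_d": "D" in parts,                 # ResNet-D variant
        "pretrained": parts[-1] == "pretrained",  # ImageNet weights
    }


info = describe_backbone("resnet_34_fpn_pretrained")
```

Note that the list contains no pretrained `D` variants, which is consistent with the remark above that pretrained ResNet-D weights are unavailable.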

3. Field Type Classification Head

The paper observes that some words could be labeled with more than one field type tag (similar to the nested named entity recognition task), and designs two classifiers to perform field type classification for each word.

To handle this multi-label problem, the authors designed a complicated two-stage classifier. We found that this classifier does not work well and is hard to fine-tune. Since the multi-label problem does not occur in the SROIE, EPHOIE, and FUNSD datasets, we replace the original design with a one-stage multi-class classifier trained with a multi-class cross-entropy loss.

Experiments show that an additional, independent binary classifier for key information may improve the final F1 score. This classifier indicates whether a text segment belongs to key information or not, which may boost recall.
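
The two losses involved can be written out for a single text segment as below. This is a plain-Python sketch so that it runs without PyTorch, and it is not the repo's implementation: treating class 0 as background and simply summing the two terms are assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Multi-class cross-entropy for one segment."""
    return -math.log(softmax(logits)[target])


def binary_cross_entropy(logit, is_key):
    """Binary cross-entropy on one logit (key information vs. not)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -math.log(p) if is_key else -math.log(1.0 - p)


def segment_loss(class_logits, class_target, key_logit):
    """Multi-class loss plus the auxiliary key/non-key loss."""
    # Assumption: class 0 is background, so any other target is "key".
    is_key = class_target != 0
    # Assumption: the two terms are summed with equal weight.
    return (cross_entropy(class_logits, class_target)
            + binary_cross_entropy(key_logit, is_key))


loss = segment_loss([0.1, 2.0, -1.0], class_target=1, key_logit=1.5)
uniform = cross_entropy([0.0, 0.0, 0.0], target=0)  # equals log(3)
```

In a PyTorch model the same quantities would typically come from `torch.nn.CrossEntropyLoss` and `torch.nn.BCEWithLogitsLoss`.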

4. Auxiliary Semantic Segmentation Head

In our experiments, the auxiliary semantic segmentation head does not help on either the SROIE or the EPHOIE dataset. You can remove this branch by setting loss_control_lambda to zero in the configuration file.

5. Tag Mode

The model can either directly predict the category of each text line/char/segment, or predict BIO tags constrained by a CRF layer. We found that the representational ability of ViBERTgrid is good enough that direct prediction works best. Using BIO tagging with CRF layers is unnecessary and has a negative effect.
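
The difference between the two tag modes lies in the label space. The sketch below converts per-segment categories to BIO tags and back; it only illustrates the tagging scheme and does not implement a CRF. The function names, the `O` background label, and the sample labels are assumptions for this example.

```python
def to_bio(classes, background="O"):
    """Convert per-segment class labels to BIO tags."""
    tags = []
    prev = background
    for cls in classes:
        if cls == background:
            tags.append(background)
        elif cls == prev:
            tags.append("I-" + cls)  # continues the current entity
        else:
            tags.append("B-" + cls)  # starts a new entity
        prev = cls
    return tags


def from_bio(tags, background="O"):
    """Recover per-segment class labels from BIO tags."""
    return [
        tag if tag == background else tag.split("-", 1)[1]
        for tag in tags
    ]


# Direct mode predicts these categories as-is.
classes = ["company", "company", "O", "total"]
# BIO mode predicts these tags instead.
tags = to_bio(classes)
```

BIO tagging roughly doubles the number of output classes, and a CRF layer is normally added to rule out invalid sequences such as an `I-` tag that follows `O`.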


Experiment Results

| Dataset | Configuration                                       | # of Parameters | F1    |
|---------|-----------------------------------------------------|-----------------|-------|
| SROIE   | original paper, BERT-Base, ResNet18-D-pretrained    | 142M            | 96.25 |
| SROIE   | original paper, RoBERTa-Base, ResNet18-D-pretrained | 147M            | 96.40 |
| SROIE   | BERT-Base uncased, ResNet34-pretrained              | 151M            | 97.16 |
| EPHOIE  | BERT-Base chinese, ResNet34-pretrained              | 145M            | 96.55 |
| FUNSD   | BERT-Base uncased, ResNet34-pretrained              | 151M            | 87.63 |

Note

  • Due to resource limitations, I used two NVIDIA TITAN X GPUs for training, which can only afford a batch size of 4 (2 on each GPU). The loss curve is not stable in this setting, which may affect the performance.
