[Feature] add project Donut #1994

Open · wants to merge 10 commits into base: dev-1.x
4 changes: 2 additions & 2 deletions README_zh-CN.md
@@ -232,10 +232,10 @@ MMOCR is jointly contributed by researchers and engineers from different universities and companies

## Welcome to the OpenMMLab Community

Scan the QR codes below to follow the OpenMMLab team's [official Zhihu account](https://www.zhihu.com/people/openmmlab), join the OpenMMLab team's [official QQ group](https://r.vansin.top/?r=join-qq), or add the WeChat account "Open小喵Lab" to join the official WeChat group.
Scan the QR code below to follow the OpenMMLab team's official Zhihu account, and scan the WeChat QR code below to add 喵喵 as a friend and join the MMOCR WeChat community. [Friend request format: research direction + region + school/company + name]

<div align="center">
<img src="https://raw.githubusercontent.com/open-mmlab/mmcv/master/docs/en/_static/zhihu_qrcode.jpg" height="400" /> <img src="https://cdn.vansin.top/OpenMMLab/q3.png" height="400" /> <img src="https://raw.githubusercontent.com/open-mmlab/mmcv/master/docs/en/_static/wechat_qrcode.jpg" height="400" />
<img src="https://raw.githubusercontent.com/open-mmlab/mmcv/master/docs/en/_static/zhihu_qrcode.jpg" height="400" /> <img src="https://github.com/open-mmlab/mmocr/assets/62195058/bf1e53fe-df4f-4296-9e1b-61db8971985e" height="400" />
</div>

In the OpenMMLab community, we will provide you with
1 change: 1 addition & 0 deletions mmocr/models/textdet/detectors/mmdet_wrapper.py
@@ -138,6 +138,7 @@ def adapt_predictions(self, data: MMDET_SampleList,
# convert by text_repr_type
if self.text_repr_type == 'quad':
for j, poly in enumerate(filterd_polygons):
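# cv2.minAreaRect expects an (N, 2) point array, not a flat polygon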
poly = poly.reshape(-1, 2)
rect = cv2.minAreaRect(poly)
vertices = cv2.boxPoints(rect)
poly = vertices.flatten()
2 changes: 2 additions & 0 deletions projects/Donut/.gitignore
@@ -0,0 +1,2 @@
/datasets
/data
134 changes: 134 additions & 0 deletions projects/Donut/README.md
@@ -0,0 +1,134 @@
# Donut

## Description

This is a reimplementation of the official Donut repository: https://github.com/clovaai/donut.

## Usage

### Prerequisites

- Python 3.7
- PyTorch 1.6 or higher
- [MIM](https://github.com/open-mmlab/mim)
- [MMOCR](https://github.com/open-mmlab/mmocr)
- transformers 4.25.1
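
One possible way to install these dependencies (a sketch; the package names below are each project's standard pip/mim distribution and are not pinned by this repo):

```shell
# install PyTorch first, following the official instructions at pytorch.org
pip install openmim                 # provides the `mim` command
pip install "transformers==4.25.1"  # the version listed above
mim install mmocr                   # installs MMOCR and its OpenMMLab dependencies
```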

All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's root directory so that Python can locate the module files. In the `Donut/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
# Linux
export PYTHONPATH=`pwd`:$PYTHONPATH
# Windows PowerShell
$env:PYTHONPATH = "$(Get-Location);$env:PYTHONPATH"
```
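
You can then verify the setup with a quick import check (a sanity-check sketch; it assumes the prerequisites above are installed and `PYTHONPATH` is set as shown):

```shell
python -c "import donut; print('donut package found')"
```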

### Training commands

In the current directory, run the following command to train the model:

```bash
mim train mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/
```

To train on multiple GPUs, e.g. 8 GPUs, run the following command:

```bash
mim train mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/ --launcher pytorch --gpus 8
```

### Testing commands

Before testing, you need to change `tokenizer_cfg` in the config: its `checkpoint` should point to the model save directory, e.g. `work_dirs/donut_cord_30e/`.
In the current directory, run the following command to test the model:

```bash
mim test mmocr configs/donut_cord_30e.py --work-dir work_dirs/donut_cord_30e/ --checkpoint ${CHECKPOINT_PATH}
```
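
For reference, here is a sketch of that edit to `tokenizer_cfg` in `configs/donut_cord_30e.py` (the `work_dirs/...` path is an example; use your actual save directory):

```python
# before: the tokenizer is loaded from the HuggingFace hub
tokenizer_cfg=dict(
    type='XLMRobertaTokenizer',
    checkpoint='naver-clova-ix/donut-base'),

# after: the tokenizer is loaded from the directory your trained
# model (and its tokenizer files) were saved to
tokenizer_cfg=dict(
    type='XLMRobertaTokenizer',
    checkpoint='work_dirs/donut_cord_30e/'),
```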

## Results

> List the results as usually done in other models' READMEs. [Example](https://github.com/open-mmlab/mmocr/blob/1.x/configs/textdet/dbnet/README.md#results-and-models)
>
> You should claim whether this is based on the pre-trained weights converted from the official release, or a reproduced result obtained by retraining the model in this project.

| Method | Pretrained Model | Training set | Test set | #epoch | Test size | TED Acc | F1 | Download |
| :-------------------------------------: | :-----------------------: | :-----------: | :----------: | :----: | :-------: | :-----: | :----: | :----------------------: |
| [Donut_CORD](configs/donut_cord_30e.py) | naver-clova-ix/donut-base | cord-v2 Train | cord-v2 Test | 30 | 736 | 0.8977 | 0.8279 | [model](<>) \| [log](<>) |

TED Acc is the tree-edit-distance-based accuracy and F1 the field-level F1 score, the two metrics reported in the Donut paper.

## Citation

<!--- cslint:disable -->

```bibtex
@article{Kim_Hong_Yim_Nam_Park_Yim_Hwang_Yun_Han_Park_2021,
title={OCR-free Document Understanding Transformer},
DOI={10.48550/arxiv.2111.15664},
author={Kim, Geewook and Hong, Teakgyu and Yim, Moonbin and Nam, Jeongyeon and Park, Jinyoung and Yim, Jinyeong and Hwang, Wonseok and Yun, Sangdoo and Han, Dongyoon and Park, Seunghyun},
year={2021},
month={Nov},
language={en-US}
}
```

<!--- cslint:enable -->

## Checklist

Here is a checklist illustrating the usual development workflow of a successful project; it also serves as an overview of this project's progress.

> The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR.
>
> OpenMMLab's maintainers will review the code to ensure the project's quality. Reaching the first milestone means that this project satisfies the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone.
>
> Note that keeping this section up-to-date is crucial not only for this project's developers but also for the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate the time and effort required for further code polishing, if needed.
>
> A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR.

- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

> The code's design shall follow existing interfaces and conventions. For example, each model component should be registered into `mmocr.registry.MODELS` and configurable via a config file.

- [x] Basic docstrings & proper citation

> Each major object should contain a docstring describing its functionality and arguments. If you have adapted the code from another open-source project, don't forget to cite the source project in the docstring and make sure your use does not violate its license. Typically, we do not accept any code snippet under a GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd)

- [ ] Test-time correctness

> If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights can usually be obtained by simply renaming the keys in the official pre-trained weights. This test can be skipped, though, if you are able to prove training-time correctness and check the second milestone.

- [x] A full README

> As this template does.

- [x] Milestone 2: Indicates a successful model implementation.

- [x] Training-time correctness

> If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range.

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

> Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96)

- [ ] Unit tests

> Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/tests/test_utils/test_polygon_utils.py#L97-L106)

- [ ] Code polishing

> Refactor your code according to reviewer's comment.

- [ ] Metafile.yml

> It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmocr/blob/1.x/configs/textdet/dbnet/metafile.yml)

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

> In particular, you may have to refactor this README into a standard one. [Example](/configs/textdet/dbnet/README.md)

40 changes: 40 additions & 0 deletions projects/Donut/configs/_base_/default_runtime.py
@@ -0,0 +1,40 @@
default_scope = 'mmocr'
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
randomness = dict(seed=None)

default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='TokenCheckpointHook', interval=1),
sampler_seed=dict(type='DistSamplerSeedHook'),
sync_buffer=dict(type='SyncBuffersHook'),
visualization=dict(
type='VisualizationHook',
interval=1,
enable=False,
show=False,
draw_gt=False,
draw_pred=False),
)

# Logging
log_level = 'INFO'
log_processor = dict(type='LogProcessor', window_size=10, by_epoch=True)

load_from = None
resume = False

vis_backends = [
dict(type='LocalVisBackend'),
dict(type='TensorboardVisBackend')
]
visualizer = dict(
type='KIELocalVisualizer',
name='visualizer',
vis_backends=vis_backends,
is_openset=False)
22 changes: 22 additions & 0 deletions projects/Donut/configs/_base_/schedules/schedule_adam_fp16.py
@@ -0,0 +1,22 @@
# optimizer
optim_wrapper = dict(
type='AmpOptimWrapper',
dtype='float16',
optimizer=dict(type='Adam', lr=3e-5, weight_decay=0.0001))
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=30, val_interval=1)
val_cfg = dict(type='ValLoop', fp16=True)
test_cfg = dict(type='TestLoop', fp16=True)
# learning rate
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=3e-5,
by_epoch=True,
begin=0,
end=3,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(type='CosineAnnealingLR', by_epoch=True, begin=3, end=30)
]
108 changes: 108 additions & 0 deletions projects/Donut/configs/donut_cord_30e.py
@@ -0,0 +1,108 @@
_base_ = [
'_base_/default_runtime.py',
'_base_/schedules/schedule_adam_fp16.py',
]

data_root = 'datasets/cord-v2'
task_name = 'cord-v2'

custom_imports = dict(imports=['donut'], allow_failed_imports=False)

# dictionary = dict(
# type='Dictionary',
# dict_file='{{ fileDirname }}/../../../dicts/english_digits_symbols.txt',
# with_padding=True,
# with_unknown=True,
# same_start_end=True,
# with_start=True,
# with_end=True)

model = dict(
type='Donut',
data_preprocessor=dict(
type='DonutDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375]),
encoder=dict(
type='SwinEncoder',
input_size=[1280, 960],
align_long_axis=False,
window_size=10,
encoder_layer=[2, 2, 14, 2],
init_cfg=dict(
type='Pretrained', checkpoint='data/donut_base_encoder.pth')),
decoder=dict(
type='BARTDecoder',
max_position_embeddings=None,
task_start_token=f'<s_{task_name}>',
prompt_end_token=f'<s_{task_name}>',
decoder_layer=4,
tokenizer_cfg=dict(
type='XLMRobertaTokenizer',
checkpoint='naver-clova-ix/donut-base'),
init_cfg=dict(
type='Pretrained', checkpoint='data/donut_base_decoder.pth')),
sort_json_key=False,
)

train_pipeline = [
dict(type='LoadImageFromFile', ignore_empty=True, min_size=2),
dict(type='LoadJsonAnnotations', with_bbox=False, with_label=False),
dict(type='TorchVisionWrapper', op='Resize', size=960, max_size=1280),
dict(type='RandomPad', input_size=[1280, 960], random_padding=True),
dict(
type='PackKIEInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'parses_json'))
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TorchVisionWrapper', op='Resize', size=960, max_size=1280),
dict(type='RandomPad', input_size=[1280, 960], random_padding=False),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadJsonAnnotations', with_bbox=False, with_label=False),
dict(
type='PackKIEInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'parses_json'))
]

# dataset settings
train_dataset = dict(
type='CORDDataset',
data_root=data_root,
split_name='train',
pipeline=train_pipeline)
val_dataset = dict(
type='CORDDataset',
data_root=data_root,
split_name='validation',
pipeline=test_pipeline)
test_dataset = dict(
type='CORDDataset',
data_root=data_root,
split_name='test',
pipeline=test_pipeline)

train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=train_dataset)

test_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=test_dataset)

val_dataloader = test_dataloader

val_evaluator = dict(type='DonutValEvaluator', key='parses')
test_evaluator = dict(type='JSONParseEvaluator', key='parses_json')

randomness = dict(seed=2022)
find_unused_parameters = True
4 changes: 4 additions & 0 deletions projects/Donut/donut/__init__.py
@@ -0,0 +1,4 @@
from .datasets import * # NOQA
from .engine import * # NOQA
from .evaluation import * # NOQA
from .model import * # NOQA
4 changes: 4 additions & 0 deletions projects/Donut/donut/datasets/__init__.py
@@ -0,0 +1,4 @@
from .cord_dataset import CORDDataset
from .transforms import * # NOQA

__all__ = ['CORDDataset']