anjali2002 commited on
Commit
b4959be
·
1 Parent(s): cc174d2

Initial commit of EasyOCR model

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. MANIFEST.in +8 -0
  3. README.md +178 -0
  4. custom_model.md +24 -0
  5. model.py +24 -0
  6. requirements.txt +12 -0
  7. scripts/.gitignore +2 -0
  8. scripts/generate-ja.rb +55 -0
  9. setup.cfg +2 -0
  10. setup.py +35 -0
  11. trainer/README.md +3 -0
  12. trainer/all_data/folder.txt +1 -0
  13. trainer/config_files/en_filtered_config.yaml +45 -0
  14. trainer/craft/.gitignore +4 -0
  15. trainer/craft/README.md +105 -0
  16. trainer/craft/config/__init__.py +0 -0
  17. trainer/craft/config/custom_data_train.yaml +100 -0
  18. trainer/craft/config/load_config.py +37 -0
  19. trainer/craft/config/syn_train.yaml +68 -0
  20. trainer/craft/data/boxEnlarge.py +65 -0
  21. trainer/craft/data/dataset.py +542 -0
  22. trainer/craft/data/gaussian.py +192 -0
  23. trainer/craft/data/imgaug.py +175 -0
  24. trainer/craft/data/imgproc.py +91 -0
  25. trainer/craft/data/pseudo_label/make_charbox.py +263 -0
  26. trainer/craft/data/pseudo_label/watershed.py +45 -0
  27. trainer/craft/data_root_dir/folder.txt +1 -0
  28. trainer/craft/eval.py +381 -0
  29. trainer/craft/exp/folder.txt +1 -0
  30. trainer/craft/loss/mseloss.py +172 -0
  31. trainer/craft/metrics/eval_det_iou.py +244 -0
  32. trainer/craft/model/craft.py +112 -0
  33. trainer/craft/model/vgg16_bn.py +77 -0
  34. trainer/craft/requirements.txt +10 -0
  35. trainer/craft/scripts/run_cde.sh +7 -0
  36. trainer/craft/train.py +479 -0
  37. trainer/craft/trainSynth.py +409 -0
  38. trainer/craft/train_distributed.py +523 -0
  39. trainer/craft/utils/craft_utils.py +345 -0
  40. trainer/craft/utils/inference_boxes.py +361 -0
  41. trainer/craft/utils/util.py +142 -0
  42. trainer/dataset.py +283 -0
  43. trainer/model.py +74 -0
  44. trainer/modules/feature_extraction.py +246 -0
  45. trainer/modules/prediction.py +81 -0
  46. trainer/modules/sequence_modeling.py +22 -0
  47. trainer/modules/transformation.py +160 -0
  48. trainer/saved_models/folder.txt +1 -0
  49. trainer/test.py +112 -0
  50. trainer/train.py +282 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MANIFEST.in ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ include LICENSE.txt
2
+ include README.md
3
+
4
+ include easyocr/model/*
5
+ include easyocr/character/*
6
+ include easyocr/dict/*
7
+ include easyocr/scripts/compile_dbnet_dcn.py
8
+ recursive-include easyocr/DBNet *
README.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # EasyOCR
2
+
3
+ [![PyPI Status](https://badge.fury.io/py/easyocr.svg)](https://badge.fury.io/py/easyocr)
4
+ [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/JaidedAI/EasyOCR/blob/master/LICENSE)
5
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.to/easyocr)
6
+ [![Tweet](https://img.shields.io/twitter/url/https/github.com/JaidedAI/EasyOCR.svg?style=social)](https://twitter.com/intent/tweet?text=Check%20out%20this%20awesome%20library:%20EasyOCR%20https://github.com/JaidedAI/EasyOCR)
7
+ [![Twitter](https://img.shields.io/badge/twitter-@JaidedAI-blue.svg?style=flat)](https://twitter.com/JaidedAI)
8
+
9
+ Ready-to-use OCR with 80+ [supported languages](https://www.jaided.ai/easyocr) and all popular writing scripts including: Latin, Chinese, Arabic, Devanagari, Cyrillic, etc.
10
+
11
+ [Try Demo on our website](https://www.jaided.ai/easyocr)
12
+
13
+ Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/tomofi/EasyOCR)
14
+
15
+
16
+ ## What's new
17
+ - 24 September 2024 - Version 1.7.2
18
+ - Fix several compatibilities
19
+
20
+ - [Read all release notes](https://github.com/JaidedAI/EasyOCR/blob/master/releasenotes.md)
21
+
22
+ ## What's coming next
23
+ - Handwritten text support
24
+
25
+ ## Examples
26
+
27
+ ![example](examples/example.png)
28
+
29
+ ![example2](examples/example2.png)
30
+
31
+ ![example3](examples/example3.png)
32
+
33
+
34
+ ## Installation
35
+
36
+ Install using `pip`
37
+
38
+ For the latest stable release:
39
+
40
+ ``` bash
41
+ pip install easyocr
42
+ ```
43
+
44
+ For the latest development release:
45
+
46
+ ``` bash
47
+ pip install git+https://github.com/JaidedAI/EasyOCR.git
48
+ ```
49
+
50
+ Note 1: For Windows, please install torch and torchvision first by following the official instructions here https://pytorch.org. On the pytorch website, be sure to select the right CUDA version you have. If you intend to run on CPU mode only, select `CUDA = None`.
51
+
52
+ Note 2: We also provide a Dockerfile [here](https://github.com/JaidedAI/EasyOCR/blob/master/Dockerfile).
53
+
54
+ ## Usage
55
+
56
+ ``` python
57
+ import easyocr
58
+ reader = easyocr.Reader(['ch_sim','en']) # this needs to run only once to load the model into memory
59
+ result = reader.readtext('chinese.jpg')
60
+ ```
61
+
62
+ The output is a list; each item contains a bounding box, the detected text, and its confidence level, respectively.
63
+
64
+ ``` bash
65
+ [([[189, 75], [469, 75], [469, 165], [189, 165]], '愚园路', 0.3754989504814148),
66
+ ([[86, 80], [134, 80], [134, 128], [86, 128]], '西', 0.40452659130096436),
67
+ ([[517, 81], [565, 81], [565, 123], [517, 123]], '东', 0.9989598989486694),
68
+ ([[78, 126], [136, 126], [136, 156], [78, 156]], '315', 0.8125889301300049),
69
+ ([[514, 126], [574, 126], [574, 156], [514, 156]], '309', 0.4971577227115631),
70
+ ([[226, 170], [414, 170], [414, 220], [226, 220]], 'Yuyuan Rd.', 0.8261902332305908),
71
+ ([[79, 173], [125, 173], [125, 213], [79, 213]], 'W', 0.9848111271858215),
72
+ ([[529, 173], [569, 173], [569, 213], [529, 213]], 'E', 0.8405593633651733)]
73
+ ```
74
+ Note 1: `['ch_sim','en']` is the list of languages you want to read. You can pass
75
+ several languages at once but not all languages can be used together.
76
+ English is compatible with every language and languages that share common characters are usually compatible with each other.
77
+
78
+ Note 2: Instead of the filepath `chinese.jpg`, you can also pass an OpenCV image object (numpy array) or an image file as bytes. A URL to a raw image is also acceptable.
79
+
80
+ Note 3: The line `reader = easyocr.Reader(['ch_sim','en'])` is for loading a model into memory. It takes some time but it needs to be run only once.
81
+
82
+ You can also set `detail=0` for simpler output.
83
+
84
+ ``` python
85
+ reader.readtext('chinese.jpg', detail = 0)
86
+ ```
87
+ Result:
88
+ ``` bash
89
+ ['愚园路', '西', '东', '315', '309', 'Yuyuan Rd.', 'W', 'E']
90
+ ```
91
+
92
+ Model weights for the chosen language will be automatically downloaded or you can
93
+ download them manually from the [model hub](https://www.jaided.ai/easyocr/modelhub) and put them in the '~/.EasyOCR/model' folder
94
+
95
+ In case you do not have a GPU, or your GPU has low memory, you can run the model in CPU-only mode by adding `gpu=False`.
96
+
97
+ ``` python
98
+ reader = easyocr.Reader(['ch_sim','en'], gpu=False)
99
+ ```
100
+
101
+ For more information, read the [tutorial](https://www.jaided.ai/easyocr/tutorial) and [API Documentation](https://www.jaided.ai/easyocr/documentation).
102
+
103
+ #### Run on command line
104
+
105
+ ```shell
106
+ $ easyocr -l ch_sim en -f chinese.jpg --detail=1 --gpu=True
107
+ ```
108
+
109
+ ## Train/use your own model
110
+
111
+ For recognition model, [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/custom_model.md).
112
+
113
+ For detection model (CRAFT), [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/trainer/craft/README.md).
114
+
115
+ ## Implementation Roadmap
116
+
117
+ - Handwritten support
118
+ - Restructure code to support swappable detection and recognition algorithms
119
+ The api should be as easy as
120
+ ``` python
121
+ reader = easyocr.Reader(['en'], detection='DB', recognition = 'Transformer')
122
+ ```
123
+ The idea is to be able to plug in any state-of-the-art model into EasyOCR. There are a lot of geniuses trying to make better detection/recognition models, but we are not trying to be geniuses here. We just want to make their works quickly accessible to the public ... for free. (well, we believe most geniuses want their work to create a positive impact as fast/big as possible) The pipeline should be something like the below diagram. Grey slots are placeholders for changeable light blue modules.
124
+
125
+ ![plan](examples/easyocr_framework.jpeg)
126
+
127
+ ## Acknowledgement and References
128
+
129
+ This project is based on research and code from several papers and open-source repositories.
130
+
131
+ All deep learning execution is based on [Pytorch](https://pytorch.org). :heart:
132
+
133
+ Detection execution uses the CRAFT algorithm from this [official repository](https://github.com/clovaai/CRAFT-pytorch) and their [paper](https://arxiv.org/abs/1904.01941) (Thanks @YoungminBaek from [@clovaai](https://github.com/clovaai)). We also use their pretrained model. Training script is provided by [@gmuffiness](https://github.com/gmuffiness).
134
+
135
+ The recognition model is a CRNN ([paper](https://arxiv.org/abs/1507.05717)). It is composed of 3 main components: feature extraction (we are currently using [Resnet](https://arxiv.org/abs/1512.03385)) and VGG, sequence labeling ([LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf)) and decoding ([CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf)). The training pipeline for recognition execution is a modified version of the [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) framework. (Thanks [@ku21fan](https://github.com/ku21fan) from [@clovaai](https://github.com/clovaai)) This repository is a gem that deserves more recognition.
136
+
137
+ Beam search code is based on this [repository](https://github.com/githubharald/CTCDecoder) and his [blog](https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7). (Thanks [@githubharald](https://github.com/githubharald))
138
+
139
+ Data synthesis is based on [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). (Thanks [@Belval](https://github.com/Belval))
140
+
141
+ And a good read about CTC from distill.pub [here](https://distill.pub/2017/ctc/).
142
+
143
+ ## Want To Contribute?
144
+
145
+ Let's advance humanity together by making AI available to everyone!
146
+
147
+ 3 ways to contribute:
148
+
149
+ **Coder:** Please send a PR for small bugs/improvements. For bigger ones, discuss with us by opening an issue first. There is a list of possible bug/improvement issues tagged with ['PR WELCOME'](https://github.com/JaidedAI/EasyOCR/issues?q=is%3Aissue+is%3Aopen+label%3A%22PR+WELCOME%22).
150
+
151
+ **User:** Tell us how EasyOCR benefits you/your organization to encourage further development. Also post failure cases in [Issue Section](https://github.com/JaidedAI/EasyOCR/issues) to help improve future models.
152
+
153
+ **Tech leader/Guru:** If you found this library useful, please spread the word! (See [Yann Lecun's post](https://www.facebook.com/yann.lecun/posts/10157018122787143) about EasyOCR)
154
+
155
+ ## Guideline for new language request
156
+
157
+ To request a new language, we need you to send a PR with the 2 following files:
158
+
159
+ 1. In folder [easyocr/character](https://github.com/JaidedAI/EasyOCR/tree/master/easyocr/character),
160
+ we need 'yourlanguagecode_char.txt' that contains list of all characters. Please see format examples from other files in that folder.
161
+ 2. In folder [easyocr/dict](https://github.com/JaidedAI/EasyOCR/tree/master/easyocr/dict),
162
+ we need 'yourlanguagecode.txt' that contains list of words in your language.
163
+ On average, we have ~30000 words per language with more than 50000 words for more popular ones.
164
+ More is better in this file.
165
+
166
+ If your language has unique elements (such as 1. Arabic: characters change form when attached to each other + write from right to left 2. Thai: Some characters need to be above the line and some below), please educate us to the best of your ability and/or give useful links. It is important to take care of the detail to achieve a system that really works.
167
+
168
+ Lastly, please understand that our priority will have to go to popular languages or sets of languages that share large portions of their characters with each other (also tell us if this is the case for your language). It takes us at least a week to develop a new model, so you may have to wait a while for the new model to be released.
169
+
170
+ See [List of languages in development](https://github.com/JaidedAI/EasyOCR/issues/91)
171
+
172
+ ## Github Issues
173
+
174
+ Due to limited resources, an issue older than 6 months will be automatically closed. Please open an issue again if it is critical.
175
+
176
+ ## Business Inquiries
177
+
178
+ For Enterprise Support, [Jaided AI](https://www.jaided.ai/) offers full service for custom OCR/AI systems from implementation, training/finetuning and deployment. Click [here](https://www.jaided.ai/contactus?ref=github) to contact us.
custom_model.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Custom recognition models
2
+
3
+ ## How to train your custom model
4
+
5
+ You can use your own data or generate your own dataset. To generate your own data, we recommend using
6
+ [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). We provide an example of a dataset [here](https://jaided.ai/easyocr/modelhub/).
7
+ After you have a dataset, you can train your own model by following this repository
8
+ [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark).
9
+ The network needs to be fully convolutional in order to predict flexible text length. Our current network is 'None-VGG-BiLSTM-CTC'.
10
+ Once you have your trained model (a `.pth` file), you need 2 additional files describing recognition network architecture and model configuration.
11
+ An example is provided in `custom_example.zip` file [here](https://jaided.ai/easyocr/modelhub/).
12
+
13
+ Please do not create an issue about data generation and model training in this repository. If you have any question regarding data generation and model training, please ask in the respective repositories.
14
+
15
+ Note: We also provide our version of a training script [here](https://github.com/JaidedAI/EasyOCR/tree/master/trainer). It is a modified version from [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark).
16
+
17
+ ## How to use your custom model
18
+
19
+ To use your own recognition model, you need the three files as explained above. These three files have to share the same name (i.e. `yourmodel.pth`, `yourmodel.yaml`, `yourmodel.py`) that you will then use to call your model with EasyOCR API.
20
+
21
+ We provide [custom_example.zip](https://jaided.ai/easyocr/modelhub/)
22
+ as an example. Please download, extract and place `custom_example.py`, `custom_example.yaml` in the `user_network_directory` (default = `~/.EasyOCR/user_network`) and place `custom_example.pth` in model directory (default = `~/.EasyOCR/model`)
23
+ Once you place all 3 files in their respective places, you can use `custom_example` by
24
+ specifying `recog_network` like this `reader = easyocr.Reader(['en'], recog_network='custom_example')`.
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import easyocr
2
+ from typing import List
3
+
4
class EasyOCRModel:
    """Thin wrapper around an ``easyocr.Reader`` for simple text extraction.

    The reader is created once in ``__init__`` (model loading is slow) and
    reused for every :meth:`predict` call.
    """

    def __init__(self, languages=None, gpu: bool = True):
        """
        Args:
            languages: Language codes to load, e.g. ``['ch_sim', 'en']``.
                Defaults to ``['en']``, matching the original behaviour.
            gpu: Whether to run on GPU (easyocr's own default is True).
        """
        if languages is None:
            languages = ['en']  # avoid a mutable default argument
        self.reader = easyocr.Reader(languages, gpu=gpu)

    def predict(self, image_path: str) -> List[str]:
        """
        Perform OCR on the given image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            List[str]: Extracted text from the image.
        """
        # detail=0 makes readtext return plain strings instead of
        # (bounding_box, text, confidence) tuples.
        return self.reader.readtext(image_path, detail=0)
19
+
20
# Quick local smoke test: run OCR on a sample image and print what was read.
if __name__ == "__main__":
    ocr = EasyOCRModel()
    texts = ocr.predict("sample_image.jpg")
    print(texts)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision>=0.5
3
+ opencv-python-headless
4
+ scipy
5
+ numpy
6
+ Pillow
7
+ scikit-image
8
+ python-bidi
9
+ PyYAML
10
+ Shapely
11
+ pyclipper
12
+ ninja
scripts/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ JMnedict*
2
+ JMdict*
scripts/generate-ja.rb ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# frozen_string_literal: true

# Regenerate EasyOCR's Japanese dictionary and character files from the
# JMdict / JMnedict dictionaries (downloaded on demand).

require 'json'
require 'nokogiri'
require 'parallel'
require 'ruby-progressbar'

JMDICT_XML = 'JMdict_e'
JMNEDICT_XML = 'JMnedict.xml'
PUNC = '【】《》〈〉⦅⦆{}[]〔〕()『』「」、;:・?〜=。!⁉︎‥…〜※*〽♪♫♬♩〇〒〶〠〄ⓍⓁⓎ→'.chars

# Download and unpack a dictionary archive next to this script,
# unless it is already present.
def download_dict(xml)
  target = File.expand_path(xml, __dir__)
  return if File.exist?(target)

  archive = "#{xml}.gz"
  url = "http://ftp.monash.edu/pub/nihongo/#{archive}"
  `cd #{File.dirname(__FILE__)} && wget #{url} && gunzip #{archive}`
end

# All spellings of one dictionary entry: kanji forms plus kana readings.
def read_word(word)
  kanji_forms = word.css('k_ele keb').map(&:text)
  readings = word.css('r_ele reb').map(&:text)
  kanji_forms + readings
end

# Parse a dictionary XML file and collect every spelling of every entry,
# processing entries in parallel threads with a progress bar.
def read_dict(filename, root)
  doc = Nokogiri::XML(File.open(File.expand_path(filename, __dir__)))
  entries = doc.css("#{root} > entry")
  Parallel.flat_map(entries, in_threads: 16, progress: root) do |entry|
    read_word(entry)
  end
end

# Write the word list, character set and punctuation files under
# ../easyocr, and print a diff against the previous character set.
def write_files(words)
  src_dir = File.expand_path('../easyocr', __dir__)
  ja_dict = File.join(src_dir, 'dict', 'ja.txt')
  ja_char = File.join(src_dir, 'character', 'ja_char2.txt')
  ja_char_old = File.join(src_dir, 'character', 'ja_char.txt')
  ja_punc = File.join(src_dir, 'character', 'ja_punc.txt')

  words -= PUNC
  chars = words.join.chars.uniq
  previous_chars = IO.read(ja_char_old).split("\n")

  puts "new characters: #{(chars - previous_chars).size}"
  puts "missing characters: #{(previous_chars - chars).size}"
  puts previous_chars - chars

  IO.write(ja_dict, words.join("\n"))
  IO.write(ja_char, chars.join("\n"))
  IO.write(ja_punc, PUNC.join("\n"))
end

download_dict(JMDICT_XML)
download_dict(JMNEDICT_XML)
words = read_dict(JMDICT_XML, 'JMdict') + read_dict(JMNEDICT_XML, 'JMnedict')
write_files(words)
setup.cfg ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [metadata]
2
+ description_file = README.md
setup.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution
3
+ """
4
+ from io import open
5
+ from setuptools import setup
6
+
7
+ with open('requirements.txt', encoding="utf-8-sig") as f:
8
+ requirements = f.readlines()
9
+
10
+ def readme():
11
+ with open('README.md', encoding="utf-8-sig") as f:
12
+ README = f.read()
13
+ return README
14
+
15
+ setup(
16
+ name='easyocr',
17
+ packages=['easyocr'],
18
+ include_package_data=True,
19
+ version='1.7.2',
20
+ install_requires=requirements,
21
+ entry_points={"console_scripts": ["easyocr= easyocr.cli:main"]},
22
+ license='Apache License 2.0',
23
+ description='End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution',
24
+ long_description=readme(),
25
+ long_description_content_type="text/markdown",
26
+ author='Rakpong Kittinaradorn',
27
+ author_email='r.kittinaradorn@gmail.com',
28
+ url='https://github.com/jaidedai/easyocr',
29
+ download_url='https://github.com/jaidedai/easyocr.git',
30
+ keywords=['ocr optical character recognition deep learning neural network'],
31
+ classifiers=[
32
+ 'Development Status :: 5 - Production/Stable'
33
+ ],
34
+
35
+ )
trainer/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # EasyOCR trainer
2
+
3
+ use `trainer.ipynb` with yaml config in `config_files` folder
trainer/all_data/folder.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ place dataset folder here
trainer/config_files/en_filtered_config.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ number: '0123456789'
2
+ symbol: "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €"
3
+ lang_char: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
4
+ experiment_name: 'en_filtered'
5
+ train_data: 'all_data'
6
+ valid_data: 'all_data/en_val'
7
+ manualSeed: 1111
8
+ workers: 6
9
+ batch_size: 32 #32
10
+ num_iter: 300000
11
+ valInterval: 20000
12
+ saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
13
+ FT: False
14
+ optim: False # default is Adadelta
15
+ lr: 1.
16
+ beta1: 0.9
17
+ rho: 0.95
18
+ eps: 0.00000001
19
+ grad_clip: 5
20
+ #Data processing
21
+ select_data: 'en_train_filtered' # this is dataset folder in train_data
22
+ batch_ratio: '1'
23
+ total_data_usage_ratio: 1.0
24
+ batch_max_length: 34
25
+ imgH: 64
26
+ imgW: 600
27
+ rgb: False
28
+ # NOTE: duplicate 'contrast_adjust' key removed here; the occurrence below (0.0) is the effective value kept by YAML loaders
29
+ sensitive: True
30
+ PAD: True
31
+ contrast_adjust: 0.0
32
+ data_filtering_off: False
33
+ # Model Architecture
34
+ Transformation: 'None'
35
+ FeatureExtraction: 'VGG'
36
+ SequenceModeling: 'BiLSTM'
37
+ Prediction: 'CTC'
38
+ num_fiducial: 20
39
+ input_channel: 1
40
+ output_channel: 256
41
+ hidden_size: 256
42
+ decode: 'greedy'
43
+ new_prediction: False
44
+ freeze_FeatureFxtraction: False
45
+ freeze_SequenceModeling: False
trainer/craft/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/
2
+ model/__pycache__/
3
+ wandb/*
4
+ vis_result/*
trainer/craft/README.md ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CRAFT-train
2
+ On the official CRAFT github, there are many people who want to train CRAFT models.
3
+
4
+ However, the training code is not published in the official CRAFT repository.
5
+
6
+ There are other reproduced codes, but there is a gap between their performance and performance reported in the original paper. (https://arxiv.org/pdf/1904.01941.pdf)
7
+
8
+ The trained model with this code recorded a level of performance similar to that of the original paper.
9
+
10
+ ```bash
11
+ ├── config
12
+ │ ├── syn_train.yaml
13
+ │ └── custom_data_train.yaml
14
+ ├── data
15
+ │ ├── pseudo_label
16
+ │ │ ├── make_charbox.py
17
+ │ │ └── watershed.py
18
+ │ ├── boxEnlarge.py
19
+ │ ├── dataset.py
20
+ │ ├── gaussian.py
21
+ │ ├── imgaug.py
22
+ │ └── imgproc.py
23
+ ├── loss
24
+ │ └── mseloss.py
25
+ ├── metrics
26
+ │ └── eval_det_iou.py
27
+ ├── model
28
+ │ ├── craft.py
29
+ │ └── vgg16_bn.py
30
+ ├── utils
31
+ │ ├── craft_utils.py
32
+ │ ├── inference_boxes.py
33
+ │ └── utils.py
34
+ ├── trainSynth.py
35
+ ├── train.py
36
+ ├── train_distributed.py
37
+ ├── eval.py
38
+ ├── data_root_dir (place dataset folder here)
39
+ └── exp (model and experiment result files will be saved here)
40
+ ```
41
+
42
+ ### Installation
43
+
44
+ Install using `pip`
45
+
46
+ ``` bash
47
+ pip install -r requirements.txt
48
+ ```
49
+
50
+
51
+ ### Training
52
+ 1. Put your training, test data in the following format
53
+ ```
54
+ └── data_root_dir (you can change root dir in yaml file)
55
+ ├── ch4_training_images
56
+ │ ├── img_1.jpg
57
+ │ └── img_2.jpg
58
+ ├── ch4_training_localization_transcription_gt
59
+ │ ├── gt_img_1.txt
60
+ │ └── gt_img_2.txt
61
+ ├── ch4_test_images
62
+ │ ├── img_1.jpg
63
+ │ └── img_2.jpg
64
+ └── ch4_test_localization_transcription_gt
65
+ ├── gt_img_1.txt
66
+ └── gt_img_2.txt
67
+ ```
68
+ * localization_transcription_gt files format :
69
+ ```
70
+ 377,117,463,117,465,130,378,130,Genaxis Theatre
71
+ 493,115,519,115,519,131,493,131,[06]
72
+ 374,155,409,155,409,170,374,170,###
73
+ ```
74
+ 2. Write configuration in yaml format (example config files are provided in `config` folder.)
75
+ * To speed up training time with multi-gpu, set num_worker > 0
76
+ 3. Put the yaml file in the config folder
77
+ 4. Run training script like below (If you have multi-gpu, run train_distributed.py)
78
+ 5. Then, experiment results will be saved to ```./exp/[yaml]``` by default.
79
+
80
+ * Step 1 : To train CRAFT with SynthText dataset from scratch
81
+ * Note : This step is not necessary if you use <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">this pretrain</a> as a checkpoint when start training step 2. You can download and put it in `exp/CRAFT_clr_amp_29500.pth` and change `ckpt_path` in the config file according to your local setup.
82
+ ```
83
+ CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
84
+ ```
85
+
86
+ * Step 2 : To train CRAFT with [SynthText + IC15] or custom dataset
87
+ ```
88
+ CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train ## if you run on single GPU
89
+ CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multi GPU
90
+ ```
91
+
92
+ ### Arguments
93
+ * ```--yaml``` : configuration file name
94
+
95
+ ### Evaluation
96
+ * In the official repository issues, the author mentioned that the first row setting F1-score is around 0.75.
97
+ * In the official paper, it is stated that the result F1-score of the second row setting is 0.87.
98
+ * If you adjust post-process parameter 'text_threshold' from 0.85 to 0.75, then F1-score reaches to 0.856.
99
+ * It took 14h to train weak-supervision 25k iteration with 8 RTX 3090 Ti.
100
+ * Half of GPU assigned for training, and half of GPU assigned for supervision setting.
101
+
102
+ | Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model |
103
+ | ------------- |-----|:-----:|:-----:|:-----:|-----:|
104
+ | SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">download link</a>|
105
+ | SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| <a href="https://drive.google.com/file/d/1qUeZIDSFCOuGS9yo8o0fi-zYHLEW6lBP/view">download link</a>|
trainer/craft/config/__init__.py ADDED
File without changes
trainer/craft/config/custom_data_train.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_opt: False
2
+
3
+ results_dir: "./exp/"
4
+ vis_test_dir: "./vis_result/"
5
+
6
+ data_root_dir: "./data_root_dir/"
7
+ score_gt_dir: None # "/data/ICDAR2015_official_supervision"  # NOTE: bare None parses as the string "None" in YAML; use null if a missing value is intended
8
+ mode: "weak_supervision"
9
+
10
+
11
+ train:
12
+ backbone : vgg
13
+ use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option
14
+ synth_data_dir: "/data/SynthText/"
15
+ synth_ratio: 5
16
+ real_dataset: custom
17
+ ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
18
+ eval_interval: 1000
19
+ batch_size: 5
20
+ st_iter: 0
21
+ end_iter: 25000
22
+ lr: 0.0001
23
+ lr_decay: 7500
24
+ gamma: 0.2
25
+ weight_decay: 0.00001
26
+ num_workers: 0 # On single gpu, train.py execution only works when num worker = 0 / On multi-gpu, you can set num_worker > 0 to speed up
27
+ amp: True
28
+ loss: 2
29
+ neg_rto: 0.3
30
+ n_min_neg: 5000
31
+ data:
32
+ vis_opt: False
33
+ pseudo_vis_opt: False
34
+ output_size: 768
35
+ do_not_care_label: ['###', '']
36
+ mean: [0.485, 0.456, 0.406]
37
+ variance: [0.229, 0.224, 0.225]
38
+ enlarge_region : [0.5, 0.5] # x axis, y axis
39
+ enlarge_affinity: [0.5, 0.5]
40
+ gauss_init_size: 200
41
+ gauss_sigma: 40
42
+ watershed:
43
+ version: "skimage"
44
+ sure_fg_th: 0.75
45
+ sure_bg_th: 0.05
46
+ syn_sample: -1
47
+ custom_sample: -1
48
+ syn_aug:
49
+ random_scale:
50
+ range: [1.0, 1.5, 2.0]
51
+ option: False
52
+ random_rotate:
53
+ max_angle: 20
54
+ option: False
55
+ random_crop:
56
+ version: "random_resize_crop_synth"
57
+ option: True
58
+ random_horizontal_flip:
59
+ option: False
60
+ random_colorjitter:
61
+ brightness: 0.2
62
+ contrast: 0.2
63
+ saturation: 0.2
64
+ hue: 0.2
65
+ option: True
66
+ custom_aug:
67
+ random_scale:
68
+ range: [ 1.0, 1.5, 2.0 ]
69
+ option: False
70
+ random_rotate:
71
+ max_angle: 20
72
+ option: True
73
+ random_crop:
74
+ version: "random_resize_crop"
75
+ scale: [0.03, 0.4]
76
+ ratio: [0.75, 1.33]
77
+ rnd_threshold: 1.0
78
+ option: True
79
+ random_horizontal_flip:
80
+ option: True
81
+ random_colorjitter:
82
+ brightness: 0.2
83
+ contrast: 0.2
84
+ saturation: 0.2
85
+ hue: 0.2
86
+ option: True
87
+
88
+ test:
89
+ trained_model : null
90
+ custom_data:
91
+ test_set_size: 500
92
+ test_data_dir: "./data_root_dir/"
93
+ text_threshold: 0.75
94
+ low_text: 0.5
95
+ link_threshold: 0.2
96
+ canvas_size: 2240
97
+ mag_ratio: 1.75
98
+ poly: False
99
+ cuda: True
100
+ vis_opt: False
trainer/craft/config/load_config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ from functools import reduce
4
+
5
+ CONFIG_PATH = os.path.dirname(__file__)
6
+
7
def load_yaml(config_name):
    """Parse ``<config_name>.yaml`` from the config directory into a dict."""
    config_path = os.path.join(CONFIG_PATH, config_name) + '.yaml'
    with open(config_path) as stream:
        parsed = yaml.safe_load(stream)
    return parsed
13
+
14
class DotDict(dict):
    """Dictionary with attribute-style and dotted-path access.

    ``d.a`` behaves like ``d['a']``, and a dotted key such as ``d['a.b.c']``
    (or a list/tuple of keys) walks nested dictionaries. Nested dict values
    are wrapped in ``DotDict`` on attribute access so chained lookups work
    (e.g. ``cfg.train.lr``).
    """

    def __getattr__(self, k):
        # Only invoked when normal attribute lookup fails, so real
        # attributes and dict methods are unaffected.
        try:
            v = self[k]
        except KeyError:
            # The original code used a bare ``except`` and then called
            # ``super().__getattr__(k)``, which neither dict nor object
            # defines — producing a confusing AttributeError. Raise a
            # proper AttributeError so hasattr()/getattr() semantics hold.
            raise AttributeError(k) from None
        if isinstance(v, dict):
            return DotDict(v)
        return v

    def __getitem__(self, k):
        if isinstance(k, str) and '.' in k:
            k = k.split('.')
        if isinstance(k, (list, tuple)):
            # Walk the nested dicts one path component at a time.
            return reduce(lambda d, kk: d[kk], k, self)
        return super().__getitem__(k)

    def get(self, k, default=None):
        if isinstance(k, str) and '.' in k:
            try:
                return self[k]
            except KeyError:
                return default
        # dict.get() takes positional-only arguments in CPython; the original
        # ``super().get(k, default=default)`` raised
        # ``TypeError: get() takes no keyword arguments``.
        return super().get(k, default)
trainer/craft/config/syn_train.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_opt: False
2
+
3
+ results_dir: "./exp/"
4
+ vis_test_dir: "./vis_result/"
5
+ data_dir:
6
+ synthtext: "/data/SynthText/"
7
+ synthtext_gt: NULL
8
+
9
+ train:
10
+ backbone : vgg
11
+ dataset: ["synthtext"]
12
+ ckpt_path: null
13
+ eval_interval: 1000
14
+ batch_size: 5
15
+ st_iter: 0
16
+ end_iter: 50000
17
+ lr: 0.0001
18
+ lr_decay: 15000
19
+ gamma: 0.2
20
+ weight_decay: 0.00001
21
+ num_workers: 4
22
+ amp: True
23
+ loss: 3
24
+ neg_rto: 1
25
+ n_min_neg: 1000
26
+ data:
27
+ vis_opt: False
28
+ output_size: 768
29
+ mean: [0.485, 0.456, 0.406]
30
+ variance: [0.229, 0.224, 0.225]
31
+ enlarge_region : [0.5, 0.5] # x axis, y axis
32
+ enlarge_affinity: [0.5, 0.5]
33
+ gauss_init_size: 200
34
+ gauss_sigma: 40
35
+ syn_sample : -1
36
+ syn_aug:
37
+ random_scale:
38
+ range: [1.0, 1.5, 2.0]
39
+ option: False
40
+ random_rotate:
41
+ max_angle: 20
42
+ option: False
43
+ random_crop:
44
+ version: "random_resize_crop_synth"
45
+ rnd_threshold : 1.0
46
+ option: True
47
+ random_horizontal_flip:
48
+ option: False
49
+ random_colorjitter:
50
+ brightness: 0.2
51
+ contrast: 0.2
52
+ saturation: 0.2
53
+ hue: 0.2
54
+ option: True
55
+
56
+ test:
57
+ trained_model: null
58
+ icdar2013:
59
+ test_set_size: 233
60
+ cuda: True
61
+ vis_opt: True
62
+ test_data_dir : "/data/ICDAR2013/"
63
+ text_threshold: 0.85
64
+ low_text: 0.5
65
+ link_threshold: 0.2
66
+ canvas_size: 960
67
+ mag_ratio: 1.5
68
+ poly: False
trainer/craft/data/boxEnlarge.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+
5
def pointAngle(Apoint, Bpoint):
    """Slope of the line through A and B; the epsilon guards a zero dx."""
    dy = Bpoint[1] - Apoint[1]
    dx = Bpoint[0] - Apoint[0]
    return dy / (dx + 10e-8)
8
+
9
def pointDistance(Apoint, Bpoint):
    """Euclidean distance between points A and B."""
    dy = Bpoint[1] - Apoint[1]
    dx = Bpoint[0] - Apoint[0]
    return math.sqrt(dy ** 2 + dx ** 2)
11
+
12
def lineBiasAndK(Apoint, Bpoint):
    """Return (slope K, intercept B) of the line through A and B."""
    slope = pointAngle(Apoint, Bpoint)
    intercept = Apoint[1] - slope * Apoint[0]
    return slope, intercept
17
+
18
def getX(K, B, Ypoint):
    """Solve ``y = K*x + B`` for x at the given y, truncated to int."""
    x = (Ypoint - B) / K
    return int(x)
20
+
21
def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):
    """Return a box corner pushed outward along the segment A-B.

    :param Apoint: segment start (for left corners, the corner being moved)
    :param Bpoint: segment end (for right corners, the corner being moved)
    :param h: image height; results are clamped inside the image
    :param w: image width; results are clamped inside the image
    :param placehold: which corner to produce: 'leftTop', 'rightTop',
        'rightBottom' or 'leftBottom'
    :param enlarge_size: (x_ratio, y_ratio) enlargement factors
    :return: (x, y) of the shifted corner as ints
    """
    # NOTE: the original code also computed ``K, B = lineBiasAndK(...)`` here,
    # but the result was never used — the dead call has been removed.
    angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
    distance = pointDistance(Apoint, Bpoint)

    x_enlarge_size, y_enlarge_size = enlarge_size

    # Split the enlargement along the axes according to the segment's angle.
    XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
    YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)

    if placehold == 'leftTop':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightTop':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightBottom':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
    elif placehold == 'leftBottom':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
    return int(x1), int(y1)
45
+
46
def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):
    """Enlarge a quadrilateral character box around its center.

    :param box: (4, 2) array of corner points
    :param h: image height (results are clamped to the image)
    :param w: image width (results are clamped to the image)
    :param enlarge_size: (x_ratio, y_ratio) enlargement factors
    :param horizontal_text_bool: if False the ratios are swapped, so the
        enlargement follows the text's vertical orientation
    :return: (4, 2) int array of the enlarged corners
    """
    if not horizontal_text_bool:
        enlarge_size = (enlarge_size[1], enlarge_size[0])

    # Rotate the corner order so the point with the smallest x+y (the
    # top-left-most corner) comes first.
    box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0)

    Apoint, Bpoint, Cpoint, Dpoint = box
    # Intersect the two diagonals (A-C and D-B) to locate the box center.
    K1, B1 = lineBiasAndK(box[0], box[2])
    K2, B2 = lineBiasAndK(box[3], box[1])
    X = (B2 - B1)/(K1 - K2)
    Y = K1 * X + B1
    center = [X, Y]

    # Push each corner outward along its half-diagonal from the center.
    x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size)
    x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size)
    x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size)
    x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size)
    newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
    return newcharbox
trainer/craft/data/dataset.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import itertools
4
+ import random
5
+
6
+ import numpy as np
7
+ import scipy.io as scio
8
+ from PIL import Image
9
+ import cv2
10
+ from torch.utils.data import Dataset
11
+ import torchvision.transforms as transforms
12
+
13
+ from data import imgproc
14
+ from data.gaussian import GaussianBuilder
15
+ from data.imgaug import (
16
+ rescale,
17
+ random_resize_crop_synth,
18
+ random_resize_crop,
19
+ random_horizontal_flip,
20
+ random_rotate,
21
+ random_scale,
22
+ random_crop,
23
+ )
24
+ from data.pseudo_label.make_charbox import PseudoCharBoxBuilder
25
+ from utils.util import saveInput, saveImage
26
+
27
+
28
class CraftBaseDataset(Dataset):
    """Base dataset for CRAFT training.

    Subclasses set ``self.img_names`` and implement either
    ``make_gt_score()`` (generate region/affinity score maps on the fly) or
    ``load_saved_gt_score()`` (load pre-computed maps from ``saved_gt_dir``).
    This base class handles optional sub-sampling, augmentation,
    visualization dumps, downscaling of the ground truth to the network's
    half-resolution output, and input normalization.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        self.output_size = output_size          # final (square) network input size
        self.data_dir = data_dir
        self.saved_gt_dir = saved_gt_dir        # if not None, load pre-saved GT maps
        self.mean, self.variance = mean, variance
        self.gaussian_builder = GaussianBuilder(
            gauss_init_size, gauss_sigma, enlarge_region, enlarge_affinity
        )
        self.aug = aug                          # augmentation config (attribute access)
        self.vis_test_dir = vis_test_dir
        self.vis_opt = vis_opt
        self.sample = sample                    # -1 means "use the whole dataset"
        if self.sample != -1:
            # Fixed seed so the sampled subset is reproducible across runs.
            # NOTE(review): subclasses assign self.img_names *after* calling
            # this __init__, so sample != -1 raises AttributeError here —
            # confirm before relying on the sampling option.
            random.seed(0)
            self.idx = random.sample(range(0, len(self.img_names)), self.sample)

        self.pre_crop_area = []

    def augment_image(
        self, image, region_score, affinity_score, confidence_mask, word_level_char_bbox
    ):
        """Apply the configured augmentation pipeline to the image and GT maps."""
        augment_targets = [image, region_score, affinity_score, confidence_mask]

        if self.aug.random_scale.option:
            augment_targets, word_level_char_bbox = random_scale(
                augment_targets, word_level_char_bbox, self.aug.random_scale.range
            )

        if self.aug.random_rotate.option:
            augment_targets = random_rotate(
                augment_targets, self.aug.random_rotate.max_angle
            )

        if self.aug.random_crop.option:
            if self.aug.random_crop.version == "random_crop_with_bbox":
                # NOTE(review): random_crop_with_bbox is not imported above,
                # so selecting this version raises NameError — confirm the
                # helper exists in data.imgaug before enabling it.
                augment_targets = random_crop_with_bbox(
                    augment_targets, word_level_char_bbox, self.output_size
                )
            elif self.aug.random_crop.version == "random_resize_crop_synth":
                augment_targets = random_resize_crop_synth(
                    augment_targets, self.output_size
                )
            elif self.aug.random_crop.version == "random_resize_crop":
                # Optionally restrict the crop to pre-computed areas.
                if len(self.pre_crop_area) > 0:
                    pre_crop_area = self.pre_crop_area
                else:
                    pre_crop_area = None

                augment_targets = random_resize_crop(
                    augment_targets,
                    self.aug.random_crop.scale,
                    self.aug.random_crop.ratio,
                    self.output_size,
                    self.aug.random_crop.rnd_threshold,
                    pre_crop_area,
                )

            elif self.aug.random_crop.version == "random_crop":
                augment_targets = random_crop(augment_targets, self.output_size,)

            else:
                # The original code had ``assert "Undefined RandomCrop version"``,
                # which never fails (a non-empty string is truthy). Fail loudly
                # instead so a config typo cannot silently skip cropping.
                raise ValueError(
                    "Undefined RandomCrop version: %s" % self.aug.random_crop.version
                )

        if self.aug.random_horizontal_flip.option:
            augment_targets = random_horizontal_flip(augment_targets)

        if self.aug.random_colorjitter.option:
            image, region_score, affinity_score, confidence_mask = augment_targets
            # ColorJitter operates on PIL images only.
            image = Image.fromarray(image)
            image = transforms.ColorJitter(
                brightness=self.aug.random_colorjitter.brightness,
                contrast=self.aug.random_colorjitter.contrast,
                saturation=self.aug.random_colorjitter.saturation,
                hue=self.aug.random_colorjitter.hue,
            )(image)
        else:
            image, region_score, affinity_score, confidence_mask = augment_targets

        return np.array(image), region_score, affinity_score, confidence_mask

    def resize_to_half(self, ground_truth, interpolation):
        """Resize a GT map to half the input size (CRAFT's output resolution)."""
        return cv2.resize(
            ground_truth,
            (self.output_size // 2, self.output_size // 2),
            interpolation=interpolation,
        )

    def __len__(self):
        if self.sample != -1:
            return len(self.idx)
        else:
            return len(self.img_names)

    def __getitem__(self, index):
        """Return (normalized CHW image, region score, affinity score, confidence mask)."""
        if self.sample != -1:
            # Map the dataset index into the sub-sampled index list.
            index = self.idx[index]
        if self.saved_gt_dir is None:
            (
                image,
                region_score,
                affinity_score,
                confidence_mask,
                word_level_char_bbox,
                all_affinity_bbox,
                words,
            ) = self.make_gt_score(index)
        else:
            (
                image,
                region_score,
                affinity_score,
                confidence_mask,
                word_level_char_bbox,
                words,
            ) = self.load_saved_gt_score(index)
            all_affinity_bbox = []

        if self.vis_opt:
            saveImage(
                self.img_names[index],
                self.vis_test_dir,
                image.copy(),
                word_level_char_bbox.copy(),
                all_affinity_bbox.copy(),
                region_score.copy(),
                affinity_score.copy(),
                confidence_mask.copy(),
            )

        image, region_score, affinity_score, confidence_mask = self.augment_image(
            image, region_score, affinity_score, confidence_mask, word_level_char_bbox
        )

        if self.vis_opt:
            saveInput(
                self.img_names[index],
                self.vis_test_dir,
                image,
                region_score,
                affinity_score,
                confidence_mask,
            )

        # CRAFT predicts at 1/2 input resolution, so GT maps are downscaled.
        region_score = self.resize_to_half(region_score, interpolation=cv2.INTER_CUBIC)
        affinity_score = self.resize_to_half(
            affinity_score, interpolation=cv2.INTER_CUBIC
        )
        # Nearest-neighbor keeps the confidence mask's discrete values intact.
        confidence_mask = self.resize_to_half(
            confidence_mask, interpolation=cv2.INTER_NEAREST
        )

        image = imgproc.normalizeMeanVariance(
            np.array(image), mean=self.mean, variance=self.variance
        )
        image = image.transpose(2, 0, 1)  # HWC -> CHW for PyTorch

        return image, region_score, affinity_score, confidence_mask
202
+
203
+
204
class SynthTextDataSet(CraftBaseDataset):
    """SynthText dataset whose character-level GT boxes come from ``gt.mat``."""

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        # Image names, per-image character boxes and word lists from gt.mat.
        self.img_names, self.char_bbox, self.img_words = self.load_data()
        self.vis_index = list(range(1000))

    def load_data(self, bbox="char"):
        """Load names, bboxes and words from SynthText's ``gt.mat``.

        :param bbox: "char" loads character boxes; anything else loads word boxes
        """
        gt = scio.loadmat(os.path.join(self.data_dir, "gt.mat"))
        img_names = gt["imnames"][0]
        img_words = gt["txt"][0]

        if bbox == "char":
            img_bbox = gt["charBB"][0]
        else:
            img_bbox = gt["wordBB"][0]  # word bbox needed for test

        return img_names, img_bbox, img_words

    def dilate_img_to_output_size(self, image, char_bbox):
        """Upscale an image (and its boxes) so its short side reaches output_size."""
        h, w, _ = image.shape
        if min(h, w) <= self.output_size:
            scale = float(self.output_size) / min(h, w)
        else:
            scale = 1.0
        image = cv2.resize(
            image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
        )
        # Boxes scale with the image.
        char_bbox *= scale
        return image, char_bbox

    def make_gt_score(self, index):
        """Build region/affinity score maps from SynthText's character boxes."""
        img_path = os.path.join(self.data_dir, self.img_names[index][0])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        all_char_bbox = self.char_bbox[index].transpose(
            (2, 1, 0)
        )  # shape : (Number of characters in image, 4, 2)

        img_h, img_w, _ = image.shape

        # SynthText is fully supervised, so every pixel has full confidence.
        confidence_mask = np.ones((img_h, img_w), dtype=np.float32)

        # gt.mat stores text blobs; split them into individual words.
        words = [
            re.split(" \n|\n |\n| ", word.strip()) for word in self.img_words[index]
        ]
        words = list(itertools.chain(*words))
        words = [word for word in words if len(word) > 0]

        # Group the flat character-box list into per-word groups by word length.
        word_level_char_bbox = []
        char_idx = 0

        for i in range(len(words)):
            length_of_word = len(words[i])
            word_bbox = all_char_bbox[char_idx : char_idx + length_of_word]
            assert len(word_bbox) == length_of_word
            char_idx += length_of_word
            word_bbox = np.array(word_bbox)
            word_level_char_bbox.append(word_bbox)

        region_score = self.gaussian_builder.generate_region(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True for _ in range(len(words))],
        )
        affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True for _ in range(len(words))],
        )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )
+ )
315
+
316
+
317
class CustomDataset(CraftBaseDataset):
    """Dataset for word-level annotated data in the ICDAR2015 folder layout.

    Character boxes are not given, so pseudo character boxes are generated
    from the current model's predictions (weak supervision); alternatively,
    pre-saved score maps are loaded when ``saved_gt_dir`` is set.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
        watershed_param,
        pseudo_vis_opt,
        do_not_care_label,
    ):
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        self.pseudo_vis_opt = pseudo_vis_opt
        # Words matching these labels (e.g. '###') are masked out of training.
        self.do_not_care_label = do_not_care_label
        self.pseudo_charbox_builder = PseudoCharBoxBuilder(
            watershed_param, vis_test_dir, pseudo_vis_opt, self.gaussian_builder
        )
        self.vis_index = list(range(1000))
        self.img_dir = os.path.join(data_dir, "ch4_training_images")
        self.img_gt_box_dir = os.path.join(
            data_dir, "ch4_training_localization_transcription_gt"
        )
        self.img_names = os.listdir(self.img_dir)

    def update_model(self, net):
        # Refresh the network used for pseudo-label generation.
        self.net = net

    def update_device(self, gpu):
        # Set the device on which pseudo-label inference runs.
        self.gpu = gpu

    def load_img_gt_box(self, img_gt_box_path):
        """Parse one ICDAR-style GT file.

        Each line is ``x1,y1,...,x4,y4,transcription``; returns a
        (N, 4, 2) float32 array of boxes plus the list of transcriptions.
        """
        # NOTE(review): the file handle is never closed; a with-block would
        # be safer.
        lines = open(img_gt_box_path, encoding="utf-8").readlines()
        word_bboxes = []
        words = []
        for line in lines:
            # The utf-8-sig round-trip strips a possible BOM from the line.
            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box_points = [int(box_info[i]) for i in range(8)]
            box_points = np.array(box_points, np.float32).reshape(4, 2)
            # The transcription itself may contain commas, so re-join the rest.
            word = box_info[8:]
            word = ",".join(word)
            if word in self.do_not_care_label:
                # Normalize every don't-care label to the first configured one.
                words.append(self.do_not_care_label[0])
                word_bboxes.append(box_points)
                continue
            word_bboxes.append(box_points)
            words.append(word)
        return np.array(word_bboxes), words

    def load_data(self, index):
        """Load one image and build pseudo character boxes for each word."""
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(
            img_gt_box_path
        )  # shape : (Number of word bbox, 4, 2)
        confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)

        word_level_char_bbox = []
        do_care_words = []
        horizontal_text_bools = []

        if len(word_bboxes) == 0:
            return (
                image,
                word_level_char_bbox,
                do_care_words,
                confidence_mask,
                horizontal_text_bools,
            )
        _word_bboxes = word_bboxes.copy()
        for i in range(len(word_bboxes)):
            if words[i] in self.do_not_care_label:
                # Zero confidence: this region is excluded from the loss.
                cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], 0)
                continue

            (
                pseudo_char_bbox,
                confidence,
                horizontal_text_bool,
            ) = self.pseudo_charbox_builder.build_char_box(
                self.net, self.gpu, image, word_bboxes[i], words[i], img_name=img_name
            )

            # Weight the word region by the pseudo-label confidence.
            cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], confidence)
            do_care_words.append(words[i])
            word_level_char_bbox.append(pseudo_char_bbox)
            horizontal_text_bools.append(horizontal_text_bool)

        return (
            image,
            word_level_char_bbox,
            do_care_words,
            confidence_mask,
            horizontal_text_bools,
        )

    def make_gt_score(self, index):
        """
        Make region, affinity scores using pseudo character-level GT bounding box
        word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
        :rtype region_score: np.float32
        :rtype affinity_score: np.float32
        :rtype confidence_mask: np.float32
        :rtype word_level_char_bbox: np.float32
        :rtype words: list
        """
        (
            image,
            word_level_char_bbox,
            words,
            confidence_mask,
            horizontal_text_bools,
        ) = self.load_data(index)
        img_h, img_w, _ = image.shape

        if len(word_level_char_bbox) == 0:
            # No usable words: train against empty score maps.
            region_score = np.zeros((img_h, img_w), dtype=np.float32)
            affinity_score = np.zeros((img_h, img_w), dtype=np.float32)
            all_affinity_bbox = []
        else:
            region_score = self.gaussian_builder.generate_region(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )
            affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )

    def load_saved_gt_score(self, index):
        """
        Load pre-saved official CRAFT model's region, affinity scores to train
        word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
        :rtype region_score: np.float32
        :rtype affinity_score: np.float32
        :rtype confidence_mask: np.float32
        :rtype word_level_char_bbox: np.float32
        :rtype words: list
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(img_gt_box_path)
        image, word_bboxes = rescale(image, word_bboxes)
        img_h, img_w, _ = image.shape

        # Saved score files are keyed by the numeric id in "img_<id>.jpg".
        query_idx = int(self.img_names[index].split(".")[0].split("_")[1])

        saved_region_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_region.jpg"
        )
        saved_affi_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_affi.jpg"
        )
        saved_cf_mask_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_cf_mask_thresh_0.6.jpg"
        )
        region_score = cv2.imread(saved_region_scores_path, cv2.IMREAD_GRAYSCALE)
        affinity_score = cv2.imread(saved_affi_scores_path, cv2.IMREAD_GRAYSCALE)
        confidence_mask = cv2.imread(saved_cf_mask_path, cv2.IMREAD_GRAYSCALE)

        region_score = cv2.resize(region_score, (img_w, img_h))
        affinity_score = cv2.resize(affinity_score, (img_w, img_h))
        confidence_mask = cv2.resize(
            confidence_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST
        )

        # The maps were saved as 8-bit images; map back to [0, 1] floats.
        region_score = region_score.astype(np.float32) / 255
        affinity_score = affinity_score.astype(np.float32) / 255
        confidence_mask = confidence_mask.astype(np.float32) / 255

        # NOTE : Even though word_level_char_bbox is not necessary, align bbox format with make_gt_score()
        word_level_char_bbox = []

        for i in range(len(word_bboxes)):
            word_level_char_bbox.append(np.expand_dims(word_bboxes[i], 0))

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            words,
        )
trainer/craft/data/gaussian.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
+ from data.boxEnlarge import enlargebox
5
+
6
+
7
class GaussianBuilder(object):
    """Builds the 2D Gaussian "region" and "affinity" score maps used as CRAFT targets.

    A single isotropic Gaussian template of side ``init_size`` is precomputed
    once and then perspective-warped onto every (enlarged) character box.
    """

    def __init__(self, init_size, sigma, enlarge_region, enlarge_affinity):
        self.init_size = init_size                # side length of the Gaussian template
        self.sigma = sigma                        # template standard deviation
        self.enlarge_region = enlarge_region      # [x, y] enlargement for region boxes
        self.enlarge_affinity = enlarge_affinity  # [x, y] enlargement for affinity boxes
        self.gaussian_map, self.gaussian_map_color = self.generate_gaussian_map()

    def generate_gaussian_map(self):
        """Return (float32 template Gaussian normalized to max 1, JET-colored uint8 copy)."""
        circle_mask = self.generate_circle_mask()

        # Vectorized isotropic Gaussian centered at init_size / 2.  Replaces the
        # original O(init_size^2) Python double loop with one numpy expression;
        # values match up to float rounding because the exponent is separable:
        # ((i - c)^2 + (j - c)^2) / sigma^2.
        axis_sq = (np.arange(self.init_size) - self.init_size / 2) ** 2 / (self.sigma ** 2)
        gaussian_map = (
            1 / 2 / np.pi / (self.sigma ** 2)
            * np.exp(-1 / 2 * (axis_sq[:, None] + axis_sq[None, :]))
        )

        gaussian_map = gaussian_map * circle_mask  # zero everything outside the circle
        gaussian_map = (gaussian_map / np.max(gaussian_map)).astype(np.float32)

        gaussian_map_color = (gaussian_map * 255).astype(np.uint8)
        gaussian_map_color = cv2.applyColorMap(gaussian_map_color, cv2.COLORMAP_JET)
        return gaussian_map, gaussian_map_color

    def generate_circle_mask(self):
        """Float mask that is 1 inside the circle inscribed in the template square."""
        zero_arr = np.zeros((self.init_size, self.init_size), np.float32)
        circle_mask = cv2.circle(
            img=zero_arr,
            center=(self.init_size // 2, self.init_size // 2),
            radius=self.init_size // 2,
            color=1,
            thickness=-1,  # filled circle
        )

        return circle_mask

    def four_point_transform(self, bbox):
        """Warp the template Gaussian onto the quadrilateral *bbox*.

        :param bbox: 4x2 float32 box; its max x / max y define the canvas size
        :return: (warped Gaussian map, canvas width, canvas height)
        """
        width, height = (
            np.max(bbox[:, 0]).astype(np.int32),
            np.max(bbox[:, 1]).astype(np.int32),
        )
        init_points = np.array(
            [
                [0, 0],
                [self.init_size, 0],
                [self.init_size, self.init_size],
                [0, self.init_size],
            ],
            dtype="float32",
        )

        M = cv2.getPerspectiveTransform(init_points, bbox)
        warped_gaussian_map = cv2.warpPerspective(self.gaussian_map, M, (width, height))
        return warped_gaussian_map, width, height

    def add_gaussian_map_to_score_map(
        self, score_map, bbox, enlarge_size, horizontal_text_bool, map_type=None
    ):
        """Stamp the 2D Gaussian onto *score_map* at the character box position.

        The box is enlarged first; where the warped Gaussian overlaps existing
        scores, the per-pixel maximum is kept.

        :param score_map: target map (np.float32, shape (H, W))
        :param bbox: 4x2 character box (np.float32)
        :param enlarge_size: [x, y] enlargement factors passed to enlargebox
        :param horizontal_text_bool: whether the parent word is horizontal text
        :param map_type: "region" | "affinity" (used only in error reporting)
        :return: updated score_map (np.float32)
        """

        map_h, map_w = score_map.shape
        bbox = enlargebox(bbox, map_h, map_w, enlarge_size, horizontal_text_bool)

        # If any one point of character bbox is out of range, don't put it on the map
        if np.any(bbox < 0) or np.any(bbox[:, 0] > map_w) or np.any(bbox[:, 1] > map_h):
            return score_map

        bbox_left, bbox_top = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(
            np.int32
        )
        bbox -= (bbox_left, bbox_top)  # warp in a local patch, paste back below
        warped_gaussian_map, width, height = self.four_point_transform(
            bbox.astype(np.float32)
        )

        try:
            bbox_area_of_image = score_map[
                bbox_top : bbox_top + height, bbox_left : bbox_left + width,
            ]
            # Keep the stronger score where Gaussians overlap.
            high_value_score = np.where(
                warped_gaussian_map > bbox_area_of_image,
                warped_gaussian_map,
                bbox_area_of_image,
            )
            score_map[
                bbox_top : bbox_top + height, bbox_left : bbox_left + width,
            ] = high_value_score

        except Exception as e:
            print("Error : {}".format(e))
            print(
                "On generating {} map, strange box came out. (width: {}, height: {})".format(
                    map_type, width, height
                )
            )

        return score_map

    def calculate_affinity_box_points(self, bbox_1, bbox_2, vertical=False):
        """Affinity box between two adjacent character boxes.

        Each corner is the centroid of the triangle formed by two corners of a
        character box and that box's center.
        """
        center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
        if vertical:
            tl = (bbox_1[0] + bbox_1[-1] + center_1) / 3
            tr = (bbox_1[1:3].sum(0) + center_1) / 3
            br = (bbox_2[1:3].sum(0) + center_2) / 3
            bl = (bbox_2[0] + bbox_2[-1] + center_2) / 3
        else:
            tl = (bbox_1[0:2].sum(0) + center_1) / 3
            tr = (bbox_2[0:2].sum(0) + center_2) / 3
            br = (bbox_2[2:4].sum(0) + center_2) / 3
            bl = (bbox_1[2:4].sum(0) + center_1) / 3
        affinity_box = np.array([tl, tr, br, bl]).astype(np.float32)
        return affinity_box

    def generate_region(
        self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
    ):
        """Build the region score map from per-word character boxes.

        :param word_level_char_bbox: list of arrays, one per word, each of
            shape [char_num_in_one_word, 4, 2]
        """
        region_map = np.zeros([img_h, img_w], dtype=np.float32)
        for i in range(len(word_level_char_bbox)):
            for j in range(len(word_level_char_bbox[i])):
                region_map = self.add_gaussian_map_to_score_map(
                    region_map,
                    word_level_char_bbox[i][j].copy(),
                    self.enlarge_region,
                    horizontal_text_bools[i],
                    map_type="region",
                )
        return region_map

    def generate_affinity(
        self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
    ):
        """Build the affinity score map between adjacent characters of each word.

        :return: (affinity_map, all affinity boxes stacked as (N, 4, 2);
            NOTE the second value stays an empty *list* when no box was made)
        """
        affinity_map = np.zeros([img_h, img_w], dtype=np.float32)
        all_affinity_bbox = []
        for i in range(len(word_level_char_bbox)):
            for j in range(len(word_level_char_bbox[i]) - 1):
                affinity_bbox = self.calculate_affinity_box_points(
                    word_level_char_bbox[i][j], word_level_char_bbox[i][j + 1]
                )

                affinity_map = self.add_gaussian_map_to_score_map(
                    affinity_map,
                    affinity_bbox.copy(),
                    self.enlarge_affinity,
                    horizontal_text_bools[i],
                    map_type="affinity",
                )
                all_affinity_bbox.append(np.expand_dims(affinity_bbox, axis=0))

        if len(all_affinity_bbox) > 0:
            all_affinity_bbox = np.concatenate(all_affinity_bbox, axis=0)
        return affinity_map, all_affinity_bbox
trainer/craft/data/imgaug.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torchvision.transforms.functional import resized_crop, crop
7
+ from torchvision.transforms import RandomResizedCrop, RandomCrop
8
+ from torchvision.transforms import InterpolationMode
9
+
10
+
11
def rescale(img, bboxes, target_size=2240):
    """Resize *img* so its longer side equals ``target_size``; scale *bboxes* to match.

    Returns (resized image, scaled bboxes); the input bboxes array is not mutated.
    """
    height, width = img.shape[0:2]
    factor = target_size / max(height, width)
    resized = cv2.resize(
        img, dsize=None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
    )
    return resized, bboxes * factor
17
+
18
+
19
def random_resize_crop_synth(augment_targets, size):
    """Crop one random square (side = image short side) shared by all targets,
    then resize each to (size, size).

    Image and score maps use bicubic resampling; the confidence mask uses
    nearest-neighbour so its values are never interpolated.
    """
    as_pil = [Image.fromarray(target) for target in augment_targets]
    image = as_pil[0]

    short_side = min(image.size)
    i, j, h, w = RandomCrop.get_params(image, output_size=(short_side, short_side))

    modes = [
        InterpolationMode.BICUBIC,  # image
        InterpolationMode.BICUBIC,  # region score
        InterpolationMode.BICUBIC,  # affinity score
        InterpolationMode.NEAREST,  # confidence mask
    ]
    return [
        np.array(resized_crop(target, i, j, h, w, (size, size), interpolation=mode))
        for target, mode in zip(as_pil, modes)
    ]
62
+
63
+
64
def random_resize_crop(
    augment_targets, scale, ratio, size, threshold, pre_crop_area=None
):
    """Apply one shared random-resized-crop to image, score maps and mask.

    :param augment_targets: [image, region_score, affinity_score, confidence_mask]
    :param scale: area scale range forwarded to RandomResizedCrop.get_params
    :param ratio: aspect-ratio range forwarded to RandomResizedCrop.get_params
    :param size: output side length (each result is size x size)
    :param threshold: probability of using the jittered (scale, ratio) crop;
        otherwise a full-image crop is sampled
    :param pre_crop_area: optional precomputed (i, j, h, w) crop window that
        bypasses random sampling entirely
    :return: [image, region_score, affinity_score, confidence_mask] as numpy arrays
    """
    image, region_score, affinity_score, confidence_mask = augment_targets

    image = Image.fromarray(image)
    region_score = Image.fromarray(region_score)
    affinity_score = Image.fromarray(affinity_score)
    confidence_mask = Image.fromarray(confidence_mask)

    # FIX: was `pre_crop_area != None`; with an array-like crop window `!=`
    # broadcasts elementwise and the truth test raises. Identity check is the
    # correct (and idiomatic) comparison against None.
    if pre_crop_area is not None:
        i, j, h, w = pre_crop_area
    elif random.random() < threshold:
        i, j, h, w = RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
    else:
        # Fallback: full-image crop, no scale/ratio jitter.
        i, j, h, w = RandomResizedCrop.get_params(
            image, scale=(1.0, 1.0), ratio=(1.0, 1.0)
        )

    image = resized_crop(
        image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC
    )
    region_score = resized_crop(
        region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC
    )
    affinity_score = resized_crop(
        affinity_score,
        i,
        j,
        h,
        w,
        (size, size),
        interpolation=InterpolationMode.BICUBIC,
    )
    # Nearest-neighbour keeps the confidence mask's discrete levels intact.
    confidence_mask = resized_crop(
        confidence_mask,
        i,
        j,
        h,
        w,
        (size, size),
        interpolation=InterpolationMode.NEAREST,
    )

    image = np.array(image)
    region_score = np.array(region_score)
    affinity_score = np.array(affinity_score)
    confidence_mask = np.array(confidence_mask)
    return [image, region_score, affinity_score, confidence_mask]
117
+
118
+
119
def random_crop(augment_targets, size):
    """Cut the same random (size x size) window out of all four targets.

    :param augment_targets: [image, region_score, affinity_score, confidence_mask]
    :return: list of cropped numpy arrays, in the same order
    """
    as_pil = [Image.fromarray(target) for target in augment_targets]

    i, j, h, w = RandomCrop.get_params(as_pil[0], output_size=(size, size))

    return [np.array(crop(target, i, j, h, w)) for target in as_pil]
141
+
142
+
143
def random_horizontal_flip(imgs):
    """With probability 0.5, mirror every array in *imgs* horizontally.

    One coin flip decides for the whole list; entries are replaced in place by
    flipped copies, and the (mutated) list is returned.
    """
    if random.random() >= 0.5:
        return imgs
    for idx, arr in enumerate(imgs):
        imgs[idx] = np.fliplr(arr).copy()
    return imgs
148
+
149
+
150
def random_scale(images, word_level_char_bbox, scale_range):
    """Resize every image and every word-level char bbox by one shared factor.

    The factor is drawn from *scale_range*; bboxes are scaled in place so the
    caller's references stay valid. Returns the (mutated) image list.
    """
    factor = random.sample(scale_range, 1)[0]

    for idx, img in enumerate(images):
        images[idx] = cv2.resize(img, dsize=None, fx=factor, fy=factor)

    for bbox in word_level_char_bbox:
        bbox *= factor  # in-place numpy scaling

    return images
160
+
161
+
162
def random_rotate(images, max_angle):
    """Rotate all images about their centers by one random angle in [-max_angle, max_angle].

    The last list entry is resampled with nearest-neighbour (presumably the
    confidence mask, to keep its discrete values — confirm against callers);
    the rest use warpAffine's default bilinear interpolation.
    """
    angle = random.random() * 2 * max_angle - max_angle
    for idx, img in enumerate(images):
        rows, cols = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        if idx == len(images) - 1:
            images[idx] = cv2.warpAffine(
                img, M=rotation_matrix, dsize=(cols, rows), flags=cv2.INTER_NEAREST
            )
        else:
            images[idx] = cv2.warpAffine(img, rotation_matrix, (cols, rows))
    return images
trainer/craft/data/imgproc.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2019-present NAVER Corp.
3
+ MIT License
4
+ """
5
+
6
+ # -*- coding: utf-8 -*-
7
+ import numpy as np
8
+
9
+ import cv2
10
+ from skimage import io
11
+
12
+
13
def loadImage(img_file):
    """Read an image file and return it as an RGB uint8 array.

    Expands grayscale to 3 channels and drops an alpha channel if present.
    """
    img = io.imread(img_file)  # skimage reads in RGB order
    if img.shape[0] == 2:
        # Presumably a multi-frame/paired load (e.g. animated formats); keep
        # the first element — confirm which inputs trigger this.
        img = img[0]
    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        img = img[:, :, :3]  # strip alpha
    return np.array(img)
24
+
25
+
26
def normalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """Standardize an RGB image given on the 0-255 scale.

    Subtracts per-channel ``mean * 255`` and divides by ``variance * 255``
    (ImageNet statistics by default). Input must be in RGB channel order.
    Returns a new float32 array; *in_img* is left untouched.
    """
    img = in_img.copy().astype(np.float32)

    img -= np.array([m * 255.0 for m in mean], dtype=np.float32)
    img /= np.array([v * 255.0 for v in variance], dtype=np.float32)
    return img
40
+
41
+
42
def denormalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """Invert :func:`normalizeMeanVariance`.

    Maps a standardized RGB image back to uint8: ``clip((x * var + mean) * 255)``.
    *in_img* is not modified.
    """
    restored = in_img.copy()
    restored *= variance
    restored += mean
    restored *= 255.0
    return np.clip(restored, 0, 255).astype(np.uint8)
52
+
53
+
54
def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
    """Resize *img* by mag_ratio (capped so the long side <= square_size),
    keeping aspect ratio, then zero-pad each side up to a multiple of 32.

    :return: (padded float32 canvas, applied ratio,
        (target_h // 2, target_w // 2) — heatmap size of the valid area)
    """
    height, width, channel = img.shape

    long_side = max(height, width)
    target_size = min(mag_ratio * long_side, square_size)  # magnify, but cap
    ratio = target_size / long_side

    target_h, target_w = int(height * ratio), int(width * ratio)
    proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # Pad each dimension up to the next multiple of 32.
    target_h32 = target_h + (32 - target_h % 32) % 32
    target_w32 = target_w + (32 - target_w % 32) % 32
    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
    resized[0:target_h, 0:target_w, :] = proc

    # NOTE: size of the valid (unpadded) region at half resolution.
    valid_size_heatmap = (int(target_h / 2), int(target_w / 2))

    return resized, ratio, valid_size_heatmap
86
+
87
+
88
def cvt2HeatmapImg(img):
    """Convert a [0, 1] score map into a JET-colormapped uint8 image."""
    scaled = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
trainer/craft/data/pseudo_label/make_charbox.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import math
4
+
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+
9
+ from data import imgproc
10
+ from data.pseudo_label.watershed import exec_watershed_by_version
11
+
12
+
13
class PseudoCharBoxBuilder:
    """Generates pseudo character-level boxes for weakly supervised CRAFT training.

    Each word region is cropped and warped upright, run through the interim
    model to get a region score, split into character candidates via watershed,
    and the candidate boxes are warped back into original-image coordinates
    together with a confidence reflecting how well the detected character count
    matches the transcription length.
    """

    def __init__(self, watershed_param, vis_test_dir, pseudo_vis_opt, gaussian_builder):
        self.watershed_param = watershed_param    # watershed config (exposes .version)
        self.vis_test_dir = vis_test_dir          # output directory for debug images
        self.pseudo_vis_opt = pseudo_vis_opt      # enable debug visualization
        self.gaussian_builder = gaussian_builder  # used only to render debug region maps
        self.cnt = 0
        self.flag = False  # True while the most recent crop was treated as vertical text

    def crop_image_by_bbox(self, image, box, word):
        """Perspective-warp the quadrilateral *box* out of *image* so the text
        reads along the horizontal axis.

        :return: (warped crop, 3x3 perspective matrix M, horizontal_text_bool)
        :raises ValueError: if the box is degenerate (zero width)
        """
        w = max(
            int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[2] - box[3]))
        )
        h = max(
            int(np.linalg.norm(box[0] - box[3])), int(np.linalg.norm(box[1] - box[2]))
        )
        if w == 0:
            # FIX: a bare except here used to drop into ipdb.set_trace(), which
            # hangs unattended training runs; fail loudly instead.
            raise ValueError(f"Degenerate word box with zero width: {box}")
        word_ratio = h / w

        one_char_ratio = min(h, w) / (max(h, w) / len(word))

        # NOTE: criterion to split vertical word in here is set to work properly on IC15 dataset
        if word_ratio > 2 or (word_ratio > 1.6 and one_char_ratio > 2.4):
            # Vertical word: rotate while warping so characters line up left-to-right.
            horizontal_text_bool = False
            long_side = h
            short_side = w
            M = cv2.getPerspectiveTransform(
                np.float32(box),
                np.float32(
                    np.array(
                        [
                            [long_side, 0],
                            [long_side, short_side],
                            [0, short_side],
                            [0, 0],
                        ]
                    )
                ),
            )
            self.flag = True
        else:
            # Horizontal word: plain axis-aligned warp.
            horizontal_text_bool = True
            long_side = w
            short_side = h
            M = cv2.getPerspectiveTransform(
                np.float32(box),
                np.float32(
                    np.array(
                        [
                            [0, 0],
                            [long_side, 0],
                            [long_side, short_side],
                            [0, short_side],
                        ]
                    )
                ),
            )
            self.flag = False

        warped = cv2.warpPerspective(image, M, (long_side, short_side))
        return warped, M, horizontal_text_bool

    def inference_word_box(self, net, gpu, word_image):
        """Run the detector on a single word crop; returns the raw score maps."""
        if net.training:
            net.eval()

        with torch.no_grad():
            word_img_torch = torch.from_numpy(
                imgproc.normalizeMeanVariance(
                    word_image,
                    mean=(0.485, 0.456, 0.406),
                    variance=(0.229, 0.224, 0.225),
                )
            )
            word_img_torch = word_img_torch.permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW
            word_img_torch = word_img_torch.type(torch.FloatTensor).cuda(gpu)
            with torch.cuda.amp.autocast():
                word_img_scores, _ = net(word_img_torch)
        return word_img_scores

    def visualize_pseudo_label(
        self, word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
    ):
        """Write a debug mosaic (crop | score | watershed boxes | pseudo boxes |
        rendered pseudo GT | overlay) to ``self.vis_test_dir``."""
        word_img_h, word_img_w, _ = word_image.shape
        word_img_cp1 = word_image.copy()
        word_img_cp2 = word_image.copy()
        _watershed_box = np.int32(watershed_box)
        _pseudo_char_bbox = np.int32(pseudo_char_bbox)

        region_score_color = cv2.applyColorMap(np.uint8(region_score), cv2.COLORMAP_JET)
        region_score_color = cv2.resize(region_score_color, (word_img_w, word_img_h))

        for box in _watershed_box:
            cv2.polylines(
                np.uint8(word_img_cp1),
                [np.reshape(box, (-1, 1, 2))],
                True,
                (255, 0, 0),
            )

        for box in _pseudo_char_bbox:
            cv2.polylines(
                np.uint8(word_img_cp2), [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0)
            )

        # NOTE: Just for visualize, put gaussian map on char box
        pseudo_gt_region_score = self.gaussian_builder.generate_region(
            word_img_h, word_img_w, [_pseudo_char_bbox], [True]
        )

        pseudo_gt_region_score = cv2.applyColorMap(
            (pseudo_gt_region_score * 255).astype("uint8"), cv2.COLORMAP_JET
        )

        overlay_img = cv2.addWeighted(
            word_image[:, :, ::-1], 0.7, pseudo_gt_region_score, 0.3, 5
        )
        vis_result = np.hstack(
            [
                word_image[:, :, ::-1],
                region_score_color,
                word_img_cp1[:, :, ::-1],
                word_img_cp2[:, :, ::-1],
                pseudo_gt_region_score,
                overlay_img,
            ]
        )

        if not os.path.exists(os.path.dirname(self.vis_test_dir)):
            os.makedirs(os.path.dirname(self.vis_test_dir))
        cv2.imwrite(
            os.path.join(
                self.vis_test_dir,
                "{}_{}".format(
                    img_name, f"pseudo_char_bbox_{random.randint(0,100)}.jpg"
                ),
            ),
            vis_result,
        )

    def clip_into_boundary(self, box, bound):
        """Clamp box coordinates into image bounds: x to [0, bound[1]], y to [0, bound[0]]."""
        if len(box) == 0:
            return box
        box[:, :, 0] = np.clip(box[:, :, 0], 0, bound[1])
        box[:, :, 1] = np.clip(box[:, :, 1], 0, bound[0])
        return box

    def get_confidence(self, real_len, pseudo_len):
        """Confidence in [0, 1]: 1 for an exact character-count match,
        decreasing linearly with the mismatch; 0 when nothing was detected."""
        if pseudo_len == 0:
            return 0.0
        return (real_len - min(real_len, abs(real_len - pseudo_len))) / real_len

    def split_word_equal_gap(self, word_img_w, word_img_h, word):
        """Fallback segmentation: split the crop into len(word) equal-width
        columns, skipping space characters. Returns (N, 4, 2) float32 boxes."""
        width = word_img_w
        height = word_img_h

        width_per_char = width / len(word)
        bboxes = []
        for j, char in enumerate(word):
            if char == " ":
                continue
            left = j * width_per_char
            right = (j + 1) * width_per_char
            bbox = np.array([[left, 0], [right, 0], [right, height], [left, height]])
            bboxes.append(bbox)

        bboxes = np.array(bboxes, np.float32)
        return bboxes

    def cal_angle(self, v1):
        """Polar angle of *v1* in [0, 2*pi), measured clockwise-safe for argsort."""
        theta = np.arccos(min(1, v1[0] / (np.linalg.norm(v1) + 10e-8)))
        return 2 * math.pi - theta if v1[1] < 0 else theta

    def clockwise_sort(self, points):
        """Sort 4 points (4x2 ndarray) by angle around their centroid."""
        v1, v2, v3, v4 = points
        center = (v1 + v2 + v3 + v4) / 4
        theta = np.array(
            [
                self.cal_angle(v1 - center),
                self.cal_angle(v2 - center),
                self.cal_angle(v3 - center),
                self.cal_angle(v4 - center),
            ]
        )
        index = np.argsort(theta)
        return np.array([v1, v2, v3, v4])[index, :]

    def build_char_box(self, net, gpu, image, word_bbox, word, img_name=""):
        """Produce pseudo character boxes for one word.

        :return: (char boxes in original-image coords, confidence, horizontal_text_bool)
        """
        word_image, M, horizontal_text_bool = self.crop_image_by_bbox(
            image, word_bbox, word
        )
        # FIX: was word.replace("\\s", ""), which removed the *literal* substring
        # "\s" (a regex typo) and left spaces in place, skewing the confidence.
        real_word_without_space = "".join(word.split())
        real_char_len = len(real_word_without_space)

        scale = 128.0 / word_image.shape[0]  # normalize crop height to 128 px

        word_image = cv2.resize(word_image, None, fx=scale, fy=scale)
        word_img_h, word_img_w, _ = word_image.shape

        scores = self.inference_word_box(net, gpu, word_image)
        region_score = scores[0, :, :, 0].cpu().data.numpy()
        region_score = np.uint8(np.clip(region_score, 0, 1) * 255)

        region_score_rgb = cv2.resize(region_score, (word_img_w, word_img_h))
        region_score_rgb = cv2.cvtColor(region_score_rgb, cv2.COLOR_GRAY2RGB)

        pseudo_char_bbox = exec_watershed_by_version(
            self.watershed_param, region_score, word_image, self.pseudo_vis_opt
        )

        # Used for visualize only
        watershed_box = pseudo_char_bbox.copy()

        pseudo_char_bbox = self.clip_into_boundary(
            pseudo_char_bbox, region_score_rgb.shape
        )

        confidence = self.get_confidence(real_char_len, len(pseudo_char_bbox))

        # Low-confidence watershed result: fall back to equal-width splitting.
        if confidence <= 0.5:
            pseudo_char_bbox = self.split_word_equal_gap(word_img_w, word_img_h, word)
            confidence = 0.5

        if self.pseudo_vis_opt and self.flag:
            self.visualize_pseudo_label(
                word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
            )

        if len(pseudo_char_bbox) != 0:
            # Order characters left-to-right by their first corner's x.
            index = np.argsort(pseudo_char_bbox[:, 0, 0])
            pseudo_char_bbox = pseudo_char_bbox[index]

        pseudo_char_bbox /= scale  # back to pre-resize crop coordinates

        # Undo the perspective warp to return to original-image coordinates.
        M_inv = np.linalg.pinv(M)
        for i in range(len(pseudo_char_bbox)):
            pseudo_char_bbox[i] = cv2.perspectiveTransform(
                pseudo_char_bbox[i][None, :, :], M_inv
            )

        pseudo_char_bbox = self.clip_into_boundary(pseudo_char_bbox, image.shape)

        return pseudo_char_bbox, confidence, horizontal_text_bool
trainer/craft/data/pseudo_label/watershed.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from skimage.segmentation import watershed
4
+
5
+
6
def segment_region_score(watershed_param, region_score, word_image, pseudo_vis_opt):
    """Split a word-level region score map into per-character boxes via watershed.

    Confident foreground (> 0.75) seeds the markers, confident background
    (< 0.05) is fixed, and the ambiguous band in between is assigned by
    watershed on the inverted score.

    :param region_score: uint8 score map in [0, 255]
    :return: (N, 4, 2) float32 axis-aligned boxes, scaled by 2 (the score map
        is presumably half the word-image resolution — confirm with callers)
    """
    region_score = np.float32(region_score) / 255
    fore = np.uint8(region_score > 0.75)  # confident character cores
    back = np.uint8(region_score < 0.05)  # confident background
    unknown = 1 - (fore + back)           # ambiguous band, decided by watershed
    ret, markers = cv2.connectedComponents(fore)
    markers += 1                          # shift so label 0 can mean "unknown"
    markers[unknown == 1] = 0

    labels = watershed(-region_score, markers)
    boxes = []
    # Background carries label 1; character components are 2..ret.
    for label in range(2, ret + 1):
        ys, xs = np.where(labels == label)
        x_min, x_max = xs.min(), xs.max()
        y_min, y_max = ys.min(), ys.max()
        box = np.array(
            [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
        )
        box *= 2
        boxes.append(box)
    return np.array(boxes, dtype=np.float32)


def exec_watershed_by_version(
    watershed_param, region_score, word_image, pseudo_vis_opt
):
    """Dispatch to the watershed implementation named by ``watershed_param.version``.

    Returns the segmentation result, or None (with a message) for an unknown
    version name.
    """
    func_name_map_dict = {
        "skimage": segment_region_score,
    }

    # FIX: a bare `except:` used to swallow *any* error raised inside the
    # segmentation function and misreport it as a missing version; only a
    # missing dictionary key is handled here, real errors propagate.
    try:
        segment_fn = func_name_map_dict[watershed_param.version]
    except KeyError:
        print(
            f"Watershed version {watershed_param.version} does not exist in func_name_map_dict."
        )
        return None
    return segment_fn(watershed_param, region_score, word_image, pseudo_vis_opt)
trainer/craft/data_root_dir/folder.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ place dataset folder here
trainer/craft/eval.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import argparse
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.backends.cudnn as cudnn
10
+ from tqdm import tqdm
11
+ import wandb
12
+
13
+ from config.load_config import load_yaml, DotDict
14
+ from model.craft import CRAFT
15
+ from metrics.eval_det_iou import DetectionIoUEvaluator
16
+ from utils.inference_boxes import (
17
+ test_net,
18
+ load_icdar2015_gt,
19
+ load_icdar2013_gt,
20
+ load_synthtext_gt,
21
+ )
22
+ from utils.util import copyStateDict
23
+
24
+
25
+
26
def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Save SynthText evaluation visualizations.

    Draws predicted boxes (green) and ground-truth boxes (red) on *img*, then
    writes the annotated image plus a region/affinity overlay mosaic into
    *result_dir*.
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Predicted polygons in green; malformed boxes are skipped silently.
    for box in pre_box:
        pts = np.array(box).astype(np.int32).reshape(-1, 2)
        try:
            cv2.polylines(
                img, [pts.reshape(-1, 1, 2)], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    # Ground-truth polygons in red.
    if gt_box is not None:
        for gt in gt_box:
            gt_pts = np.array(gt["points"]).astype(np.int32).reshape(-1, 1, 2)
            cv2.polylines(img, [gt_pts], True, color=(0, 0, 255), thickness=2)

    overlay_img = overlay(img_copy, region, affinity, pre_box)

    cv2.imwrite(result_dir + "/res_" + filename + ".jpg", img)
    cv2.imwrite(result_dir + "/res_" + filename + "_box.jpg", overlay_img)
68
+
69
def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir):
    """Save ICDAR2015 evaluation visualizations.

    Predicted boxes are green. GT boxes are red, except don't-care regions
    (transcription "###"), which are drawn gray. Writes the annotated image
    and a region/affinity overlay mosaic to *result_dir*.
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Predicted polygons in green; malformed boxes are skipped silently.
    for box in pre_box:
        pts = np.array(box).astype(np.int32).reshape(-1, 2)
        try:
            cv2.polylines(
                img, [pts.reshape(-1, 1, 2)], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    if gt_box is not None:
        for gt in gt_box:
            gt_pts = np.array(gt["points"]).reshape(-1, 2).astype(np.int32)
            # "###" marks an unreadable / don't-care region.
            gt_color = (128, 128, 128) if gt["text"] == "###" else (0, 0, 255)
            cv2.polylines(img, [gt_pts], True, color=gt_color, thickness=2)

    overlay_img = overlay(img_copy, region, affinity, pre_box)

    cv2.imwrite(result_dir + "/res_" + filename + ".jpg", img)
    cv2.imwrite(result_dir + "/res_" + filename + "_box.jpg", overlay_img)
106
+
107
+
108
def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Save ICDAR2013 evaluation visualizations.

    Draws predicted boxes (green) and ground-truth boxes (red) on *img*, then
    writes the annotated image plus a region/affinity overlay mosaic into
    *result_dir*.
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    filename = os.path.splitext(os.path.basename(img_file))[0]

    # Predicted polygons in green; malformed boxes are skipped silently.
    for box in pre_box:
        pts = np.array(box).astype(np.int32).reshape(-1, 2)
        try:
            cv2.polylines(
                img, [pts.reshape(-1, 1, 2)], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    # Ground-truth polygons in red.
    if gt_box is not None:
        for gt in gt_box:
            cv2.polylines(
                img,
                [np.array(gt["points"]).reshape(-1, 1, 2)],
                True,
                color=(0, 0, 255),
                thickness=2,
            )

    overlay_img = overlay(img_copy, region, affinity, pre_box)

    cv2.imwrite(result_dir + "/res_" + filename + ".jpg", img)
    cv2.imwrite(result_dir + "/res_" + filename + "_box.jpg", overlay_img)
149
+
150
+
151
def overlay(image, region, affinity, single_img_bbox):
    """Build a 2x2 debug mosaic for one image.

    Top row: original | original with predicted boxes (green).
    Bottom row: region-score overlay | affinity-score overlay.
    """
    height, width = image.shape[:2]

    region_score = cv2.resize(region, (width, height))
    affinity_score = cv2.resize(affinity, (width, height))

    overlay_region = cv2.addWeighted(image.copy(), 0.4, region_score, 0.6, 5)
    overlay_aff = cv2.addWeighted(image.copy(), 0.4, affinity_score, 0.6, 5)

    boxed_img = image.copy()
    for word_box in single_img_bbox:
        cv2.polylines(
            boxed_img,
            [word_box.astype(np.int32).reshape(-1, 1, 2)],
            True,
            color=(0, 255, 0),
            thickness=3,
        )

    top_row = np.hstack([image, boxed_img])
    bottom_row = np.hstack([overlay_region, overlay_aff])
    return np.vstack([top_row, bottom_row])
176
+
177
+
178
def load_test_dataset_iou(test_folder_name, config):
    """Load GT boxes and image paths for the named test set.

    "custom_data" reuses the ICDAR2015 ground-truth loader. Returns
    (total_bboxes_gt, total_img_path), or (None, None) for an unknown name.
    """
    if test_folder_name == "synthtext":
        return load_synthtext_gt(config.test_data_dir)

    if test_folder_name == "icdar2013":
        return load_icdar2013_gt(dataFolder=config.test_data_dir)

    if test_folder_name in ("icdar2015", "custom_data"):
        return load_icdar2015_gt(dataFolder=config.test_data_dir)

    print("not found test dataset")
    return None, None
203
+
204
+
205
def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name):
    """Save a per-dataset result visualization for one test image.

    Dispatches to the dataset-specific ``save_result_*`` writer; unknown
    dataset names are reported and ignored. ``custom_data`` uses the
    ICDAR2015 writer.
    """
    savers = {
        "synthtext": save_result_synth,
        "icdar2013": save_result_2013,
        "icdar2015": save_result_2015,
        "custom_data": save_result_2015,
    }

    saver = savers.get(test_folder_name)
    if saver is None:
        print("not found test dataset")
        return
    # img[:, :, ::-1] reverses the channel order (BGR<->RGB) before saving.
    saver(img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir)
225
+
226
+
227
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode):
    """Run IoU-based detection evaluation on the "custom_data" test set.

    Two calling modes:
      * ``model is None`` — standalone evaluation: a CRAFT model is built and
        its weights are loaded from ``model_path``.
      * ``model`` given — evaluation during training (possibly distributed):
        each GPU runs inference on its slice of the image list and publishes
        per-image results through the shared ``buffer``.

    Returns the combined metrics dict from ``evaluator.combine_results``.
    """
    if not os.path.exists(result_dir):
        os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config)

    # In weak supervision with more than one GPU, half the GPUs are assumed to
    # be occupied by the supervision model, so only the other half evaluate.
    if mode == "weak_supervision" and torch.cuda.device_count() != 1:
        gpu_count = torch.cuda.device_count() // 2
    else:
        gpu_count = torch.cuda.device_count()
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    # Only evaluation time
    if model is None:
        piece_imgs_path = total_imgs_path

        if backbone == "vgg":
            model = CRAFT()
        else:
            raise Exception("Undefined architecture")

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False

    # Distributed evaluation in the middle of training time
    else:
        if buffer is not None:
            # check all buffer value is None for distributed evaluation
            assert all(
                v is None for v in buffer
            ), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count

        # last gpu takes the remainder of the division as well
        if gpu_idx == gpu_count - 1:
            piece_imgs_path = total_imgs_path[gpu_idx * slice_idx :]
            # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx:]
        else:
            piece_imgs_path = total_imgs_path[
                gpu_idx * slice_idx : (gpu_idx + 1) * slice_idx
            ]
            # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx: (gpu_idx + 1) * slice_idx]

    model.eval()

    # Inference over this process' share of images --------------------------------------------------------------#
    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        single_img_bbox = []
        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        # Predictions carry placeholder text "###" — only geometry is scored.
        for box in bboxes:
            box_info = {"points": box, "text": "###", "ignore": False}
            single_img_bbox.append(box_info)
        total_imgs_bboxes_pre.append(single_img_bbox)
        # Distributed evaluation: publish this image's result at its global index.
        if buffer is not None:
            buffer[gpu_idx * slice_idx + k] = single_img_bbox
            # print(sum([element is not None for element in buffer]))
        # -------------------------------------------------------------------------------------------------------#

        if config.vis_opt:
            viz_test(
                image,
                score_text,
                pre_box=polys,
                gt_box=total_imgs_bboxes_gt[k],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    # When distributed evaluation mode, busy-wait until every slot of the
    # shared buffer has been filled by the other processes.
    if buffer is not None:
        while None in buffer:
            continue
        assert all(v is not None for v in buffer), "Buffer not filled"
        total_imgs_bboxes_pre = buffer

    results = []
    for i, (gt, pred) in enumerate(zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)):
        perSampleMetrics_dict = evaluator.evaluate_image(gt, pred)
        results.append(perSampleMetrics_dict)

    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
333
+
334
def cal_eval(config, data, res_dir_name, opt, mode):
    """Entry point for offline evaluation of a trained checkpoint.

    NOTE(review): reads the module-level ``args`` (parsed in ``__main__``) for
    the experiment directory name instead of taking it as a parameter, so this
    is only callable after argument parsing has run.
    """
    evaluator = DetectionIoUEvaluator()
    test_config = DotDict(config.test[data])
    # Results land under exp/<yaml-name>/<res_dir_name>.
    res_dir = os.path.join(os.path.join("exp", args.yaml), "{}".format(res_dir_name))

    if opt == "iou_eval":
        main_eval(
            config.test.trained_model,
            config.train.backbone,
            test_config,
            evaluator,
            res_dir,
            buffer=None,
            model=None,
            mode=mode,
        )
    else:
        print("Undefined evaluation")
352
+
353
+
354
+ if __name__ == "__main__":
355
+
356
+ parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
357
+ parser.add_argument(
358
+ "--yaml",
359
+ "--yaml_file_name",
360
+ default="custom_data_train",
361
+ type=str,
362
+ help="Load configuration",
363
+ )
364
+ args = parser.parse_args()
365
+
366
+ # load configure
367
+ config = load_yaml(args.yaml)
368
+ config = DotDict(config)
369
+
370
+ if config["wandb_opt"]:
371
+ wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
372
+ wandb.config.update(config)
373
+
374
+ val_result_dir_name = args.yaml
375
+ cal_eval(
376
+ config,
377
+ "custom_data",
378
+ val_result_dir_name + "-ic15-iou",
379
+ opt="iou_eval",
380
+ mode=None,
381
+ )
trainer/craft/exp/folder.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ trained model will be saved here
trainer/craft/loss/mseloss.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class Loss(nn.Module):
    """Confidence-weighted mean squared error over region and affinity maps."""

    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
        """Return mean of (region error + affinity error) weighted by conf_map."""
        region_err = (gt_region - pred_region).pow(2)
        affinity_err = (gt_affinity - pred_affinity).pow(2)
        weighted = (region_err + affinity_err) * conf_map
        return torch.mean(weighted)
15
+
16
+
17
class Maploss_v2(nn.Module):
    """Region/affinity map loss with batch-level online hard negative mining.

    Positive pixels (label > 0.1) contribute their mean squared error; the
    hardest negatives are mined up to ``neg_rto`` times the positive count
    (never fewer than ``n_min_neg``).
    """

    def __init__(self):

        super(Maploss_v2, self).__init__()

    def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
        """Combine mean positive loss with hard-negative-mined negative loss.

        Args:
            pred_score: per-pixel (already confidence-masked) squared-error map.
            label_score: ground-truth score map; pixels > 0.1 count as positive.
            neg_rto: maximum ratio of mined negatives to positives.
            n_min_neg: minimum number of negatives to mine.
        """
        # positive_loss
        positive_pixel = (label_score > 0.1).float()
        positive_pixel_number = torch.sum(positive_pixel)

        positive_loss_region = pred_score * positive_pixel

        # negative_loss
        negative_pixel = (label_score <= 0.1).float()
        negative_pixel_number = torch.sum(negative_pixel)
        negative_loss_region = pred_score * negative_pixel

        if positive_pixel_number != 0:
            if negative_pixel_number < neg_rto * positive_pixel_number:
                # Too few negatives for the requested ratio: fall back to the
                # n_min_neg hardest negatives.
                negative_loss = (
                    torch.sum(
                        torch.topk(
                            negative_loss_region.view(-1), n_min_neg, sorted=False
                        )[0]
                    )
                    / n_min_neg
                )
            else:
                negative_loss = torch.sum(
                    torch.topk(
                        negative_loss_region.view(-1),
                        int(neg_rto * positive_pixel_number),
                        sorted=False,
                    )[0]
                ) / (positive_pixel_number * neg_rto)
            positive_loss = torch.sum(positive_loss_region) / positive_pixel_number
        else:
            # only negative pixels in this batch
            negative_loss = (
                torch.sum(
                    torch.topk(negative_loss_region.view(-1), n_min_neg, sorted=False)[
                        0
                    ]
                )
                / n_min_neg
            )
            positive_loss = 0.0
        total_loss = positive_loss + negative_loss
        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_socres_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return summed region + affinity losses for a batch.

        ``mask`` is the per-pixel confidence map multiplied into both
        squared-error maps before mining.
        """
        # `reduce=False, size_average=False` is the long-deprecated spelling of
        # reduction="none"; the modern argument keeps per-pixel losses without
        # emitting deprecation warnings.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_socres_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_socres_label)

        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)

        char_loss = self.batch_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.batch_image_loss(
            loss_affinity, affinity_socres_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
96
+
97
+
98
class Maploss_v3(nn.Module):
    """Region/affinity map loss with per-image online hard negative mining.

    Unlike ``Maploss_v2``, mining is done independently for each sample in
    the batch and the per-sample losses are averaged.
    """

    def __init__(self):

        super(Maploss_v3, self).__init__()

    def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
        """Average per-image (positive mean + hard-negative mean) over the batch.

        Args:
            pre_loss: (B, H, W) confidence-masked squared-error map.
            loss_label: (B, H, W) ground-truth score map; >= 0.1 is positive.
            neg_rto: maximum ratio of mined negatives to positives.
            n_min_neg: minimum number of negatives to mine.
        """
        batch_size = pre_loss.shape[0]

        positive_loss, negative_loss = 0, 0
        for single_loss, single_label in zip(pre_loss, loss_label):

            # positive_loss (1e-12 guards the all-negative case against /0)
            pos_pixel = (single_label >= 0.1).float()
            n_pos_pixel = torch.sum(pos_pixel)
            pos_loss_region = single_loss * pos_pixel
            positive_loss += torch.sum(pos_loss_region) / max(n_pos_pixel, 1e-12)

            # negative_loss
            neg_pixel = (single_label < 0.1).float()
            n_neg_pixel = torch.sum(neg_pixel)
            neg_loss_region = single_loss * neg_pixel

            if n_pos_pixel != 0:
                if n_neg_pixel < neg_rto * n_pos_pixel:
                    # Fewer negatives than the ratio allows: use them all.
                    negative_loss += torch.sum(neg_loss_region) / n_neg_pixel
                else:
                    n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel)
                    # n_hard_neg = neg_rto*n_pos_pixel
                    negative_loss += (
                        torch.sum(
                            torch.topk(neg_loss_region.view(-1), int(n_hard_neg))[0]
                        )
                        / n_hard_neg
                    )
            else:
                # only negative pixels in this sample
                negative_loss += (
                    torch.sum(torch.topk(neg_loss_region.view(-1), n_min_neg)[0])
                    / n_min_neg
                )

        total_loss = (positive_loss + negative_loss) / batch_size

        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_scores_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return summed region + affinity losses, mined per image."""
        # `reduce=False, size_average=False` is the long-deprecated spelling of
        # reduction="none"; the modern argument keeps per-pixel losses without
        # emitting deprecation warnings.
        loss_fn = torch.nn.MSELoss(reduction="none")

        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_scores_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_scores_label)

        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)
        char_loss = self.single_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.single_image_loss(
            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
        )

        return char_loss + affi_loss
trainer/craft/metrics/eval_det_iou.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ from collections import namedtuple
4
+ import numpy as np
5
+ from shapely.geometry import Polygon
6
+ """
7
+ cite from:
8
+ PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR
9
+ PaddleOCR reference from :
10
+ https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
11
+ """
12
+
13
+
14
class DetectionIoUEvaluator(object):
    """ICDAR-style text detection evaluator.

    Greedily matches predicted polygons to ground-truth polygons one-to-one
    by IoU. A pair counts as a match when IoU > ``iou_constraint``; a
    detection overlapping a "don't care" ground-truth region by more than
    ``area_precision_constraint`` of its own area is excluded from scoring.
    """

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        self.iou_constraint = iou_constraint
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Score one image.

        Args:
            gt: list of dicts with 'points' (polygon) and 'ignore' (don't care).
            pred: list of dicts with 'points'.

        Returns:
            Per-sample metrics dict (precision/recall/hmean/pairs/...).
        """
        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / get_union(pD, pG)

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        def compute_ap(confList, matchList, numGtCare):
            # Average precision over confidence-sorted matches (unused in the
            # IoU-only flow below, kept for completeness).
            correct = 0
            AP = 0
            if len(confList) > 0:
                confList = np.array(confList)
                matchList = np.array(matchList)
                sorted_ind = np.argsort(-confList)
                confList = confList[sorted_ind]
                matchList = matchList[sorted_ind]
                for n in range(len(confList)):
                    match = matchList[n]
                    if match:
                        correct += 1
                        AP += float(correct) / (n + 1)

                if numGtCare > 0:
                    AP /= numGtCare

            return AP

        perSampleMetrics = {}

        matchedSum = 0

        Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

        numGlobalCareGt = 0
        numGlobalCareDet = 0

        arrGlobalConfidences = []
        arrGlobalMatches = []

        recall = 0
        precision = 0
        hmean = 0

        detMatched = 0

        iouMat = np.empty([1, 1])

        gtPols = []
        detPols = []

        gtPolPoints = []
        detPolPoints = []

        # Array of Ground Truth Polygons' keys marked as don't Care
        gtDontCarePolsNum = []
        # Array of Detected Polygons' matched with a don't Care GT
        detDontCarePolsNum = []

        pairs = []
        detMatchedNums = []

        arrSampleConfidences = []
        arrSampleMatch = []

        evaluationLog = ""

        # print(len(gt))

        # Collect valid GT polygons; invalid/self-intersecting ones are skipped.
        for n in range(len(gt)):
            points = gt[n]['points']
            # transcription = gt[n]['text']
            dontCare = gt[n]['ignore']
            # points = Polygon(points)
            # points = points.buffer(0)
            try:
                if not Polygon(points).is_valid or not Polygon(points).is_simple:
                    continue
            # NOTE(review): debug leftover — bare except drops into ipdb on any
            # Polygon construction error; should be removed for production use.
            except:
                import ipdb;
                ipdb.set_trace()

            #import ipdb;ipdb.set_trace()
            gtPol = points
            gtPols.append(gtPol)
            gtPolPoints.append(points)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
            if len(gtDontCarePolsNum) > 0 else "\n")

        # Collect valid detections; mark those mostly inside a don't-care GT.
        for n in range(len(pred)):
            points = pred[n]['points']
            # points = Polygon(points)
            # points = points.buffer(0)
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            detPol = points
            detPols.append(detPol)
            detPolPoints.append(points)
            if len(gtDontCarePolsNum) > 0:
                for dontCarePol in gtDontCarePolsNum:
                    dontCarePol = gtPols[dontCarePol]
                    intersected_area = get_intersection(dontCarePol, detPol)
                    pdDimensions = Polygon(detPol).area
                    precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                    if (precision > self.area_precision_constraint):
                        detDontCarePolsNum.append(len(detPols) - 1)
                        break

        evaluationLog += "DET polygons: " + str(len(detPols)) + (
            " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
            if len(detDontCarePolsNum) > 0 else "\n")

        if len(gtPols) > 0 and len(detPols) > 0:
            # Calculate IoU and precision matrices
            outputShape = [len(gtPols), len(detPols)]
            iouMat = np.empty(outputShape)
            gtRectMat = np.zeros(len(gtPols), np.int8)
            detRectMat = np.zeros(len(detPols), np.int8)
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    pG = gtPols[gtNum]
                    pD = detPols[detNum]
                    iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)

            # Greedy one-to-one matching: first unmatched pair above the IoU
            # threshold wins; don't-care entries never participate.
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    if gtRectMat[gtNum] == 0 and detRectMat[
                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                        if iouMat[gtNum, detNum] > self.iou_constraint:
                            gtRectMat[gtNum] = 1
                            detRectMat[detNum] = 1
                            detMatched += 1
                            pairs.append({'gt': gtNum, 'det': detNum})
                            detMatchedNums.append(detNum)
                            evaluationLog += "Match GT #" + \
                                str(gtNum) + " with Det #" + str(detNum) + "\n"

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        # No cared GT: recall is trivially 1; precision is 1 only if nothing
        # was (wrongly) detected.
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare

        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        matchedSum += detMatched
        numGlobalCareGt += numGtCare
        numGlobalCareDet += numDetCare

        perSampleMetrics = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            # Large matrices are dropped to keep the result dict lightweight.
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPolPoints,
            'detPolPoints': detPolPoints,
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

        return perSampleMetrics

    def combine_results(self, results):
        """Micro-average per-sample results into dataset-level P/R/hmean."""
        numGlobalCareGt = 0
        numGlobalCareDet = 0
        matchedSum = 0
        for result in results:
            numGlobalCareGt += result['gtCare']
            numGlobalCareDet += result['detCare']
            matchedSum += result['detMatched']

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (
                methodRecall + methodPrecision)
        # print(methodRecall, methodPrecision, methodHmean)
        # sys.exit(-1)
        methodMetrics = {
            'precision': methodPrecision,
            'recall': methodRecall,
            'hmean': methodHmean
        }

        return methodMetrics
222
+
223
+
224
if __name__ == '__main__':
    # Smoke test: two GT boxes, one prediction that nearly coincides with the
    # first GT box — expect one match (recall 0.5, precision 1.0).
    evaluator = DetectionIoUEvaluator()
    gts = [[{
        'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
        'text': 1234,
        'ignore': False,
    }, {
        'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
        'text': 5678,
        'ignore': False,
    }]]
    preds = [[{
        'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
        'text': 123,
        'ignore': False,
    }]]
    results = []
    for gt, pred in zip(gts, preds):
        results.append(evaluator.evaluate_image(gt, pred))
    metrics = evaluator.combine_results(results)
    print(metrics)
trainer/craft/model/craft.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2019-present NAVER Corp.
3
+ MIT License
4
+ """
5
+
6
+ # -*- coding: utf-8 -*-
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from model.vgg16_bn import vgg16_bn, init_weights
12
+
13
+ class double_conv(nn.Module):
14
+ def __init__(self, in_ch, mid_ch, out_ch):
15
+ super(double_conv, self).__init__()
16
+ self.conv = nn.Sequential(
17
+ nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
18
+ nn.BatchNorm2d(mid_ch),
19
+ nn.ReLU(inplace=True),
20
+ nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
21
+ nn.BatchNorm2d(out_ch),
22
+ nn.ReLU(inplace=True)
23
+ )
24
+
25
+ def forward(self, x):
26
+ x = self.conv(x)
27
+ return x
28
+
29
+
30
class CRAFT(nn.Module):
    """CRAFT text detector: VGG16-BN backbone + U-net decoder + 2-channel head.

    The head predicts region and affinity score maps; ``forward`` returns
    ``(scores, feature)`` where ``scores`` is permuted to NHWC with 2 channels
    and ``feature`` is the last decoder feature map.
    """

    def __init__(self, pretrained=True, freeze=False, amp=False):
        super(CRAFT, self).__init__()

        # When True, forward runs under torch.cuda.amp.autocast.
        self.amp = amp

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2  # region score + affinity score
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        # Backbone may be pretrained; decoder and head always start fresh.
        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def _decode(self, sources):
        """U-net decoder + classification head, shared by amp/non-amp paths.

        ``sources`` is the backbone's tuple of feature maps ordered deepest
        first; each stage upsamples to the next skip's spatial size, then
        concatenates and convolves.
        """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        # NHWC layout for the score maps, plus the raw decoder feature.
        return y.permute(0, 2, 3, 1), feature

    def forward(self, x):
        """Run backbone + decoder; under amp the whole pass is autocast.

        The original implementation duplicated the entire decoder in both
        branches; the shared ``_decode`` helper removes that copy-paste.
        """
        if self.amp:
            with torch.cuda.amp.autocast():
                sources = self.basenet(x)
                return self._decode(sources)

        sources = self.basenet(x)
        return self._decode(sources)
108
+
109
if __name__ == '__main__':
    # Smoke test: push a dummy 768x768 RGB batch through the model.
    # Requires a CUDA device and downloads VGG16-BN weights if not cached.
    model = CRAFT(pretrained=True).cuda()
    output, _ = model(torch.randn(1, 3, 768, 768).cuda())
    print(output.shape)
trainer/craft/model/vgg16_bn.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import torchvision
5
+ from torchvision import models
6
+ from packaging import version
7
+
8
def init_weights(modules):
    """Initialize conv/batchnorm/linear layers in place.

    Conv2d weights get Xavier-uniform init (bias zeroed when present);
    BatchNorm2d starts as an identity affine transform; Linear weights are
    drawn from N(0, 0.01) with zero bias. Other module types are untouched.
    """
    for module in modules:
        if isinstance(module, nn.Conv2d):
            init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.BatchNorm2d):
            # Identity affine: scale 1, shift 0.
            module.weight.data.fill_(1)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(0, 0.01)
            module.bias.data.zero_()
+
21
+
22
class vgg16_bn(torch.nn.Module):
    """VGG16-BN backbone sliced into five stages for CRAFT.

    ``forward`` returns feature maps ordered deepest-first:
    ``(fc7, relu5_3, relu4_3, relu3_2, relu2_2)`` — matching the skip
    connections consumed by the CRAFT decoder.
    """

    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        # torchvision >= 0.13 replaced the `pretrained=` flag with `weights=`.
        if version.parse(torchvision.__version__) >= version.parse('0.13'):
            vgg_pretrained_features = models.vgg16_bn(
                weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
            ).features
        else: # torchvision.__version__ < 0.13
            # Download over plain HTTP (workaround for SSL issues in old envs).
            models.vgg.model_urls['vgg16_bn'] = models.vgg.model_urls['vgg16_bn'].replace('https://', 'http://')
            vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features

        # Partition the torchvision feature stack into stages; the layer-index
        # boundaries below target conv2_2 / conv3_3 / conv4_3 / conv5_3 —
        # NOTE(review): comments follow the original CRAFT repo; verify against
        # the actual vgg16_bn layer indices if torchvision's layout changes.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):         # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):         # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):         # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):         # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
                nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())        # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():      # only first conv
                param.requires_grad= False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        # Deepest-first tuple consumed by the CRAFT decoder.
        return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2
trainer/craft/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ conda==4.10.3
2
+ opencv-python==4.5.3.56
3
+ Pillow==9.3.0
4
+ Polygon3==3.0.9.1
5
+ PyYAML==5.4.1
6
+ scikit-image==0.17.2
7
+ Shapely==1.8.0
8
torch==1.13.1
# NOTE(review): torchvision 0.10.0 is the companion release of torch 1.9.x;
# the matching release for torch 1.13.1 is torchvision 0.14.1 — confirm the
# intended pinning before installing.
torchvision==0.10.0
10
+ wandb==0.12.9
trainer/craft/scripts/run_cde.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # sed -i -e 's/\r$//' scripts/run_cde.sh
2
+ EXP_NAME=custom_data_release_test_3
3
+ yaml_path="config/$EXP_NAME.yaml"
4
+ cp config/custom_data_train.yaml $yaml_path
5
+ #CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=$EXP_NAME --port=2468
6
+ CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=$EXP_NAME --port=2468
7
+ rm "config/$EXP_NAME.yaml"
trainer/craft/train.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import argparse
3
+ import os
4
+ import shutil
5
+ import time
6
+ import multiprocessing as mp
7
+ import yaml
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ import wandb
14
+
15
+ from config.load_config import load_yaml, DotDict
16
+ from data.dataset import SynthTextDataSet, CustomDataset
17
+ from loss.mseloss import Maploss_v2, Maploss_v3
18
+ from model.craft import CRAFT
19
+ from eval import main_eval
20
+ from metrics.eval_det_iou import DetectionIoUEvaluator
21
+ from utils.util import copyStateDict, save_parser
22
+
23
+
24
+ class Trainer(object):
25
    def __init__(self, config, gpu, mode):
        """Store run configuration and preload the checkpoint (if configured).

        Args:
            config: DotDict configuration loaded from YAML.
            gpu: GPU index this trainer process runs on.
            mode: training mode string (e.g. "weak_supervision") or None.
        """
        self.config = config
        self.gpu = gpu
        self.mode = mode
        # Checkpoint state dict loaded onto this GPU, or None when no
        # ckpt_path is configured.
        self.net_param = self.get_load_param(gpu)
31
+
32
    def get_synth_loader(self):
        """Build the SynthText DataLoader.

        The batch size is divided by ``synth_ratio`` so synthetic samples fill
        only a fraction of each combined training batch.
        """
        dataset = SynthTextDataSet(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.train.synth_data_dir,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            aug=self.config.train.data.syn_aug,
            vis_test_dir=self.config.vis_test_dir,
            vis_opt=self.config.train.data.vis_opt,
            sample=self.config.train.data.syn_sample,
        )

        syn_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return syn_loader
59
+
60
    def get_custom_dataset(self):
        """Build the real-image CustomDataset (no DataLoader wrapping here).

        All behavior is driven by ``config.train.data``; pseudo-labels for
        weak supervision are generated inside the dataset using the
        configured watershed parameters.
        """
        custom_dataset = CustomDataset(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.data_root_dir,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            watershed_param=self.config.train.data.watershed,
            aug=self.config.train.data.custom_aug,
            vis_test_dir=self.config.vis_test_dir,
            sample=self.config.train.data.custom_sample,
            vis_opt=self.config.train.data.vis_opt,
            pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
            do_not_care_label=self.config.train.data.do_not_care_label,
        )

        return custom_dataset
82
+
83
+ def get_load_param(self, gpu):
84
+
85
+ if self.config.train.ckpt_path is not None:
86
+ map_location = "cuda:%d" % gpu
87
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
88
+ else:
89
+ param = None
90
+
91
+ return param
92
+
93
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
94
+ lr = lr * (gamma ** step)
95
+ for param_group in optimizer.param_groups:
96
+ param_group["lr"] = lr
97
+ return param_group["lr"]
98
+
99
+ def get_loss(self):
100
+ if self.config.train.loss == 2:
101
+ criterion = Maploss_v2()
102
+ elif self.config.train.loss == 3:
103
+ criterion = Maploss_v3()
104
+ else:
105
+ raise Exception("Undefined loss")
106
+ return criterion
107
+
108
    def iou_eval(self, dataset, train_step, buffer, model):
        """Run IoU evaluation at the given training step and log to wandb.

        Delegates to ``main_eval`` with the in-training model (``buffer`` is
        used for distributed result gathering); results are written under
        ``results_dir/<dataset>_iou/<train_step>``.
        """
        test_config = DotDict(self.config.test[dataset])

        val_result_dir = os.path.join(
            self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
        )

        evaluator = DetectionIoUEvaluator()

        metrics = main_eval(
            None,
            self.config.train.backbone,
            test_config,
            evaluator,
            val_result_dir,
            buffer,
            model,
            self.mode,
        )
        # Only the rank-0 process logs, and only when wandb is enabled.
        if self.gpu == 0 and self.config.wandb_opt:
            wandb.log(
                {
                    "{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
                    "{} iou Precision".format(dataset): np.round(
                        metrics["precision"], 3
                    ),
                    "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
                }
            )
137
+
138
+ def train(self, buffer_dict):
139
+
140
+ torch.cuda.set_device(self.gpu)
141
+
142
+ # MODEL -------------------------------------------------------------------------------------------------------#
143
+ # SUPERVISION model
144
+ if self.config.mode == "weak_supervision":
145
+ if self.config.train.backbone == "vgg":
146
+ supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
147
+ else:
148
+ raise Exception("Undefined architecture")
149
+
150
+ supervision_device = self.gpu
151
+ if self.config.train.ckpt_path is not None:
152
+ supervision_param = self.get_load_param(supervision_device)
153
+ supervision_model.load_state_dict(
154
+ copyStateDict(supervision_param["craft"])
155
+ )
156
+ supervision_model = supervision_model.to(f"cuda:{supervision_device}")
157
+ print(f"Supervision model loading on : gpu {supervision_device}")
158
+ else:
159
+ supervision_model, supervision_device = None, None
160
+
161
+ # TRAIN model
162
+ if self.config.train.backbone == "vgg":
163
+ craft = CRAFT(pretrained=False, amp=self.config.train.amp)
164
+ else:
165
+ raise Exception("Undefined architecture")
166
+
167
+ if self.config.train.ckpt_path is not None:
168
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
169
+
170
+ craft = craft.cuda()
171
+ craft = torch.nn.DataParallel(craft)
172
+
173
+ torch.backends.cudnn.benchmark = True
174
+
175
+ # DATASET -----------------------------------------------------------------------------------------------------#
176
+
177
+ if self.config.train.use_synthtext:
178
+ trn_syn_loader = self.get_synth_loader()
179
+ batch_syn = iter(trn_syn_loader)
180
+
181
+ if self.config.train.real_dataset == "custom":
182
+ trn_real_dataset = self.get_custom_dataset()
183
+ else:
184
+ raise Exception("Undefined dataset")
185
+
186
+ if self.config.mode == "weak_supervision":
187
+ trn_real_dataset.update_model(supervision_model)
188
+ trn_real_dataset.update_device(supervision_device)
189
+
190
+ trn_real_loader = torch.utils.data.DataLoader(
191
+ trn_real_dataset,
192
+ batch_size=self.config.train.batch_size,
193
+ shuffle=False,
194
+ num_workers=self.config.train.num_workers,
195
+ drop_last=False,
196
+ pin_memory=True,
197
+ )
198
+
199
+ # OPTIMIZER ---------------------------------------------------------------------------------------------------#
200
+ optimizer = optim.Adam(
201
+ craft.parameters(),
202
+ lr=self.config.train.lr,
203
+ weight_decay=self.config.train.weight_decay,
204
+ )
205
+
206
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
207
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
208
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
209
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
210
+
211
+ # LOSS --------------------------------------------------------------------------------------------------------#
212
+ # mixed precision
213
+ if self.config.train.amp:
214
+ scaler = torch.cuda.amp.GradScaler()
215
+
216
+ if (
217
+ self.config.train.ckpt_path is not None
218
+ and self.config.train.st_iter != 0
219
+ ):
220
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
221
+ else:
222
+ scaler = None
223
+
224
+ criterion = self.get_loss()
225
+
226
+ # TRAIN -------------------------------------------------------------------------------------------------------#
227
+ train_step = self.config.train.st_iter
228
+ whole_training_step = self.config.train.end_iter
229
+ update_lr_rate_step = 0
230
+ training_lr = self.config.train.lr
231
+ loss_value = 0
232
+ batch_time = 0
233
+ start_time = time.time()
234
+
235
+ print(
236
+ "================================ Train start ================================"
237
+ )
238
+ while train_step < whole_training_step:
239
+ for (
240
+ index,
241
+ (
242
+ images,
243
+ region_scores,
244
+ affinity_scores,
245
+ confidence_masks,
246
+ ),
247
+ ) in enumerate(trn_real_loader):
248
+ craft.train()
249
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
250
+ update_lr_rate_step += 1
251
+ training_lr = self.adjust_learning_rate(
252
+ optimizer,
253
+ self.config.train.gamma,
254
+ update_lr_rate_step,
255
+ self.config.train.lr,
256
+ )
257
+
258
+ images = images.cuda(non_blocking=True)
259
+ region_scores = region_scores.cuda(non_blocking=True)
260
+ affinity_scores = affinity_scores.cuda(non_blocking=True)
261
+ confidence_masks = confidence_masks.cuda(non_blocking=True)
262
+
263
+ if self.config.train.use_synthtext:
264
+ # Synth image load
265
+ syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
266
+ batch_syn
267
+ )
268
+ syn_image = syn_image.cuda(non_blocking=True)
269
+ syn_region_label = syn_region_label.cuda(non_blocking=True)
270
+ syn_affi_label = syn_affi_label.cuda(non_blocking=True)
271
+ syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
272
+
273
+ # concat syn & custom image
274
+ images = torch.cat((syn_image, images), 0)
275
+ region_image_label = torch.cat(
276
+ (syn_region_label, region_scores), 0
277
+ )
278
+ affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
279
+ confidence_mask_label = torch.cat(
280
+ (syn_confidence_mask, confidence_masks), 0
281
+ )
282
+ else:
283
+ region_image_label = region_scores
284
+ affinity_image_label = affinity_scores
285
+ confidence_mask_label = confidence_masks
286
+
287
+ if self.config.train.amp:
288
+ with torch.cuda.amp.autocast():
289
+
290
+ output, _ = craft(images)
291
+ out1 = output[:, :, :, 0]
292
+ out2 = output[:, :, :, 1]
293
+
294
+ loss = criterion(
295
+ region_image_label,
296
+ affinity_image_label,
297
+ out1,
298
+ out2,
299
+ confidence_mask_label,
300
+ self.config.train.neg_rto,
301
+ self.config.train.n_min_neg,
302
+ )
303
+
304
+ optimizer.zero_grad()
305
+ scaler.scale(loss).backward()
306
+ scaler.step(optimizer)
307
+ scaler.update()
308
+
309
+ else:
310
+ output, _ = craft(images)
311
+ out1 = output[:, :, :, 0]
312
+ out2 = output[:, :, :, 1]
313
+ loss = criterion(
314
+ region_image_label,
315
+ affinity_image_label,
316
+ out1,
317
+ out2,
318
+ confidence_mask_label,
319
+ self.config.train.neg_rto,
320
+ )
321
+
322
+ optimizer.zero_grad()
323
+ loss.backward()
324
+ optimizer.step()
325
+
326
+ end_time = time.time()
327
+ loss_value += loss.item()
328
+ batch_time += end_time - start_time
329
+
330
+ if train_step > 0 and train_step % 5 == 0:
331
+ mean_loss = loss_value / 5
332
+ loss_value = 0
333
+ avg_batch_time = batch_time / 5
334
+ batch_time = 0
335
+
336
+ print(
337
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
338
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
339
+ time.strftime(
340
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
341
+ ),
342
+ train_step,
343
+ whole_training_step,
344
+ training_lr,
345
+ mean_loss,
346
+ avg_batch_time,
347
+ )
348
+ )
349
+
350
+ if self.config.wandb_opt:
351
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
352
+
353
+ if (
354
+ train_step % self.config.train.eval_interval == 0
355
+ and train_step != 0
356
+ ):
357
+
358
+ craft.eval()
359
+
360
+ print("Saving state, index:", train_step)
361
+ save_param_dic = {
362
+ "iter": train_step,
363
+ "craft": craft.state_dict(),
364
+ "optimizer": optimizer.state_dict(),
365
+ }
366
+ save_param_path = (
367
+ self.config.results_dir
368
+ + "/CRAFT_clr_"
369
+ + repr(train_step)
370
+ + ".pth"
371
+ )
372
+
373
+ if self.config.train.amp:
374
+ save_param_dic["scaler"] = scaler.state_dict()
375
+ save_param_path = (
376
+ self.config.results_dir
377
+ + "/CRAFT_clr_amp_"
378
+ + repr(train_step)
379
+ + ".pth"
380
+ )
381
+
382
+ torch.save(save_param_dic, save_param_path)
383
+
384
+ # validation
385
+ self.iou_eval(
386
+ "custom_data",
387
+ train_step,
388
+ buffer_dict["custom_data"],
389
+ craft,
390
+ )
391
+
392
+ train_step += 1
393
+ if train_step >= whole_training_step:
394
+ break
395
+
396
+ if self.config.mode == "weak_supervision":
397
+ state_dict = craft.module.state_dict()
398
+ supervision_model.load_state_dict(state_dict)
399
+ trn_real_dataset.update_model(supervision_model)
400
+
401
+ # save last model
402
+ save_param_dic = {
403
+ "iter": train_step,
404
+ "craft": craft.state_dict(),
405
+ "optimizer": optimizer.state_dict(),
406
+ }
407
+ save_param_path = (
408
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
409
+ )
410
+
411
+ if self.config.train.amp:
412
+ save_param_dic["scaler"] = scaler.state_dict()
413
+ save_param_path = (
414
+ self.config.results_dir
415
+ + "/CRAFT_clr_amp_"
416
+ + repr(train_step)
417
+ + ".pth"
418
+ )
419
+ torch.save(save_param_dic, save_param_path)
420
+
421
+
422
+ def main():
423
+ parser = argparse.ArgumentParser(description="CRAFT custom data train")
424
+ parser.add_argument(
425
+ "--yaml",
426
+ "--yaml_file_name",
427
+ default="custom_data_train",
428
+ type=str,
429
+ help="Load configuration",
430
+ )
431
+ parser.add_argument(
432
+ "--port", "--use ddp port", default="2346", type=str, help="Port number"
433
+ )
434
+
435
+ args = parser.parse_args()
436
+
437
+ # load configure
438
+ exp_name = args.yaml
439
+ config = load_yaml(args.yaml)
440
+
441
+ print("-" * 20 + " Options " + "-" * 20)
442
+ print(yaml.dump(config))
443
+ print("-" * 40)
444
+
445
+ # Make result_dir
446
+ res_dir = os.path.join(config["results_dir"], args.yaml)
447
+ config["results_dir"] = res_dir
448
+ if not os.path.exists(res_dir):
449
+ os.makedirs(res_dir)
450
+
451
+ # Duplicate yaml file to result_dir
452
+ shutil.copy(
453
+ "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
454
+ )
455
+
456
+ if config["mode"] == "weak_supervision":
457
+ mode = "weak_supervision"
458
+ else:
459
+ mode = None
460
+
461
+
462
+ # Apply config to wandb
463
+ if config["wandb_opt"]:
464
+ wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
465
+ wandb.config.update(config)
466
+
467
+ config = DotDict(config)
468
+
469
+ # Start train
470
+ buffer_dict = {"custom_data":None}
471
+ trainer = Trainer(config, 0, mode)
472
+ trainer.train(buffer_dict)
473
+
474
+ if config["wandb_opt"]:
475
+ wandb.finish()
476
+
477
+
478
+ if __name__ == "__main__":
479
+ main()
trainer/craft/trainSynth.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import argparse
3
+ import os
4
+ import shutil
5
+ import time
6
+ import yaml
7
+ import multiprocessing as mp
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ import wandb
14
+
15
+ from config.load_config import load_yaml, DotDict
16
+ from data.dataset import SynthTextDataSet
17
+ from loss.mseloss import Maploss_v2, Maploss_v3
18
+ from model.craft import CRAFT
19
+ from metrics.eval_det_iou import DetectionIoUEvaluator
20
+ from eval import main_eval
21
+ from utils.util import copyStateDict, save_parser
22
+
23
+
24
+ class Trainer(object):
25
+ def __init__(self, config, gpu):
26
+
27
+ self.config = config
28
+ self.gpu = gpu
29
+ self.mode = None
30
+ self.trn_loader, self.trn_sampler = self.get_trn_loader()
31
+ self.net_param = self.get_load_param(gpu)
32
+
33
+ def get_trn_loader(self):
34
+
35
+ dataset = SynthTextDataSet(
36
+ output_size=self.config.train.data.output_size,
37
+ data_dir=self.config.data_dir.synthtext,
38
+ saved_gt_dir=None,
39
+ mean=self.config.train.data.mean,
40
+ variance=self.config.train.data.variance,
41
+ gauss_init_size=self.config.train.data.gauss_init_size,
42
+ gauss_sigma=self.config.train.data.gauss_sigma,
43
+ enlarge_region=self.config.train.data.enlarge_region,
44
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
45
+ aug=self.config.train.data.syn_aug,
46
+ vis_test_dir=self.config.vis_test_dir,
47
+ vis_opt=self.config.train.data.vis_opt,
48
+ sample=self.config.train.data.syn_sample,
49
+ )
50
+
51
+ trn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
52
+
53
+ trn_loader = torch.utils.data.DataLoader(
54
+ dataset,
55
+ batch_size=self.config.train.batch_size,
56
+ shuffle=False,
57
+ num_workers=self.config.train.num_workers,
58
+ sampler=trn_sampler,
59
+ drop_last=True,
60
+ pin_memory=True,
61
+ )
62
+ return trn_loader, trn_sampler
63
+
64
+ def get_load_param(self, gpu):
65
+ if self.config.train.ckpt_path is not None:
66
+ map_location = {"cuda:%d" % 0: "cuda:%d" % gpu}
67
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
68
+ else:
69
+ param = None
70
+ return param
71
+
72
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
73
+ lr = lr * (gamma ** step)
74
+ for param_group in optimizer.param_groups:
75
+ param_group["lr"] = lr
76
+ return param_group["lr"]
77
+
78
+ def get_loss(self):
79
+ if self.config.train.loss == 2:
80
+ criterion = Maploss_v2()
81
+ elif self.config.train.loss == 3:
82
+ criterion = Maploss_v3()
83
+ else:
84
+ raise Exception("Undefined loss")
85
+ return criterion
86
+
87
+ def iou_eval(self, dataset, train_step, save_param_path, buffer, model):
88
+ test_config = DotDict(self.config.test[dataset])
89
+
90
+ val_result_dir = os.path.join(
91
+ self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
92
+ )
93
+
94
+ evaluator = DetectionIoUEvaluator()
95
+
96
+ metrics = main_eval(
97
+ save_param_path,
98
+ self.config.train.backbone,
99
+ test_config,
100
+ evaluator,
101
+ val_result_dir,
102
+ buffer,
103
+ model,
104
+ self.mode,
105
+ )
106
+ if self.gpu == 0 and self.config.wandb_opt:
107
+ wandb.log(
108
+ {
109
+ "{} IoU Recall".format(dataset): np.round(metrics["recall"], 3),
110
+ "{} IoU Precision".format(dataset): np.round(
111
+ metrics["precision"], 3
112
+ ),
113
+ "{} IoU F1-score".format(dataset): np.round(metrics["hmean"], 3),
114
+ }
115
+ )
116
+
117
+ def train(self, buffer_dict):
118
+ torch.cuda.set_device(self.gpu)
119
+
120
+ # DATASET -----------------------------------------------------------------------------------------------------#
121
+ trn_loader = self.trn_loader
122
+
123
+ # MODEL -------------------------------------------------------------------------------------------------------#
124
+ if self.config.train.backbone == "vgg":
125
+ craft = CRAFT(pretrained=True, amp=self.config.train.amp)
126
+ else:
127
+ raise Exception("Undefined architecture")
128
+
129
+ if self.config.train.ckpt_path is not None:
130
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
131
+ craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
132
+ craft = craft.cuda()
133
+ craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])
134
+
135
+ torch.backends.cudnn.benchmark = True
136
+
137
+ # OPTIMIZER----------------------------------------------------------------------------------------------------#
138
+
139
+ optimizer = optim.Adam(
140
+ craft.parameters(),
141
+ lr=self.config.train.lr,
142
+ weight_decay=self.config.train.weight_decay,
143
+ )
144
+
145
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
146
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
147
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
148
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
149
+
150
+ # LOSS --------------------------------------------------------------------------------------------------------#
151
+ # mixed precision
152
+ if self.config.train.amp:
153
+ scaler = torch.cuda.amp.GradScaler()
154
+
155
+ # load model
156
+ if (
157
+ self.config.train.ckpt_path is not None
158
+ and self.config.train.st_iter != 0
159
+ ):
160
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
161
+ else:
162
+ scaler = None
163
+
164
+ criterion = self.get_loss()
165
+
166
+ # TRAIN -------------------------------------------------------------------------------------------------------#
167
+ train_step = self.config.train.st_iter
168
+ whole_training_step = self.config.train.end_iter
169
+ update_lr_rate_step = 0
170
+ training_lr = self.config.train.lr
171
+ loss_value = 0
172
+ batch_time = 0
173
+ epoch = 0
174
+ start_time = time.time()
175
+
176
+ while train_step < whole_training_step:
177
+ self.trn_sampler.set_epoch(train_step)
178
+ for (
179
+ index,
180
+ (image, region_image, affinity_image, confidence_mask,),
181
+ ) in enumerate(trn_loader):
182
+ craft.train()
183
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
184
+ update_lr_rate_step += 1
185
+ training_lr = self.adjust_learning_rate(
186
+ optimizer,
187
+ self.config.train.gamma,
188
+ update_lr_rate_step,
189
+ self.config.train.lr,
190
+ )
191
+
192
+ images = image.cuda(non_blocking=True)
193
+ region_image_label = region_image.cuda(non_blocking=True)
194
+ affinity_image_label = affinity_image.cuda(non_blocking=True)
195
+ confidence_mask_label = confidence_mask.cuda(non_blocking=True)
196
+
197
+ if self.config.train.amp:
198
+ with torch.cuda.amp.autocast():
199
+
200
+ output, _ = craft(images)
201
+ out1 = output[:, :, :, 0]
202
+ out2 = output[:, :, :, 1]
203
+
204
+ loss = criterion(
205
+ region_image_label,
206
+ affinity_image_label,
207
+ out1,
208
+ out2,
209
+ confidence_mask_label,
210
+ self.config.train.neg_rto,
211
+ self.config.train.n_min_neg,
212
+ )
213
+
214
+ optimizer.zero_grad()
215
+ scaler.scale(loss).backward()
216
+ scaler.step(optimizer)
217
+ scaler.update()
218
+
219
+ else:
220
+ output, _ = craft(images)
221
+ out1 = output[:, :, :, 0]
222
+ out2 = output[:, :, :, 1]
223
+ loss = criterion(
224
+ region_image_label,
225
+ affinity_image_label,
226
+ out1,
227
+ out2,
228
+ confidence_mask_label,
229
+ self.config.train.neg_rto,
230
+ )
231
+
232
+ optimizer.zero_grad()
233
+ loss.backward()
234
+ optimizer.step()
235
+
236
+ end_time = time.time()
237
+ loss_value += loss.item()
238
+ batch_time += end_time - start_time
239
+
240
+ if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
241
+ mean_loss = loss_value / 5
242
+ loss_value = 0
243
+ avg_batch_time = batch_time / 5
244
+ batch_time = 0
245
+
246
+ print(
247
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
248
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
249
+ time.strftime(
250
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
251
+ ),
252
+ train_step,
253
+ whole_training_step,
254
+ training_lr,
255
+ mean_loss,
256
+ avg_batch_time,
257
+ )
258
+ )
259
+ if self.gpu == 0 and self.config.wandb_opt:
260
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
261
+
262
+ if (
263
+ train_step % self.config.train.eval_interval == 0
264
+ and train_step != 0
265
+ ):
266
+
267
+ # initialize all buffer value with zero
268
+ if self.gpu == 0:
269
+ for buffer in buffer_dict.values():
270
+ for i in range(len(buffer)):
271
+ buffer[i] = None
272
+
273
+ print("Saving state, index:", train_step)
274
+ save_param_dic = {
275
+ "iter": train_step,
276
+ "craft": craft.state_dict(),
277
+ "optimizer": optimizer.state_dict(),
278
+ }
279
+ save_param_path = (
280
+ self.config.results_dir
281
+ + "/CRAFT_clr_"
282
+ + repr(train_step)
283
+ + ".pth"
284
+ )
285
+
286
+ if self.config.train.amp:
287
+ save_param_dic["scaler"] = scaler.state_dict()
288
+ save_param_path = (
289
+ self.config.results_dir
290
+ + "/CRAFT_clr_amp_"
291
+ + repr(train_step)
292
+ + ".pth"
293
+ )
294
+
295
+ if self.gpu == 0:
296
+ torch.save(save_param_dic, save_param_path)
297
+
298
+ # validation
299
+ self.iou_eval(
300
+ "icdar2013",
301
+ train_step,
302
+ save_param_path,
303
+ buffer_dict["icdar2013"],
304
+ craft,
305
+ )
306
+
307
+ train_step += 1
308
+ if train_step >= whole_training_step:
309
+ break
310
+ epoch += 1
311
+
312
+ # save last model
313
+ if self.gpu == 0:
314
+ save_param_dic = {
315
+ "iter": train_step,
316
+ "craft": craft.state_dict(),
317
+ "optimizer": optimizer.state_dict(),
318
+ }
319
+ save_param_path = (
320
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
321
+ )
322
+
323
+ if self.config.train.amp:
324
+ save_param_dic["scaler"] = scaler.state_dict()
325
+ save_param_path = (
326
+ self.config.results_dir
327
+ + "/CRAFT_clr_amp_"
328
+ + repr(train_step)
329
+ + ".pth"
330
+ )
331
+ torch.save(save_param_dic, save_param_path)
332
+
333
+
334
+ def main():
335
+ parser = argparse.ArgumentParser(description="CRAFT SynthText Train")
336
+ parser.add_argument(
337
+ "--yaml",
338
+ "--yaml_file_name",
339
+ default="syn_train",
340
+ type=str,
341
+ help="Load configuration",
342
+ )
343
+ parser.add_argument(
344
+ "--port", "--use ddp port", default="2646", type=str, help="Load configuration"
345
+ )
346
+
347
+ args = parser.parse_args()
348
+
349
+ # load configure
350
+ exp_name = args.yaml
351
+ config = load_yaml(args.yaml)
352
+
353
+ print("-" * 20 + " Options " + "-" * 20)
354
+ print(yaml.dump(config))
355
+ print("-" * 40)
356
+
357
+ # Make result_dir
358
+ res_dir = os.path.join(config["results_dir"], args.yaml)
359
+ config["results_dir"] = res_dir
360
+ if not os.path.exists(res_dir):
361
+ os.makedirs(res_dir)
362
+
363
+ # Duplicate yaml file to result_dir
364
+ shutil.copy(
365
+ "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
366
+ )
367
+
368
+ ngpus_per_node = torch.cuda.device_count()
369
+ print(f"Total device num : {ngpus_per_node}")
370
+
371
+ manager = mp.Manager()
372
+ buffer1 = manager.list([None] * config["test"]["icdar2013"]["test_set_size"])
373
+ buffer_dict = {"icdar2013": buffer1}
374
+ torch.multiprocessing.spawn(
375
+ main_worker,
376
+ nprocs=ngpus_per_node,
377
+ args=(args.port, ngpus_per_node, config, buffer_dict, exp_name,),
378
+ )
379
+ print('flag5')
380
+
381
+
382
+ def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name):
383
+
384
+ torch.distributed.init_process_group(
385
+ backend="nccl",
386
+ init_method="tcp://127.0.0.1:" + port,
387
+ world_size=ngpus_per_node,
388
+ rank=gpu,
389
+ )
390
+
391
+ # Apply config to wandb
392
+ if gpu == 0 and config["wandb_opt"]:
393
+ wandb.init(project="craft-stage1", entity="gmuffiness", name=exp_name)
394
+ wandb.config.update(config)
395
+
396
+ batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
397
+ config["train"]["batch_size"] = batch_size
398
+ config = DotDict(config)
399
+
400
+ # Start train
401
+ trainer = Trainer(config, gpu)
402
+ trainer.train(buffer_dict)
403
+
404
+ if gpu == 0 and config["wandb_opt"]:
405
+ wandb.finish()
406
+ torch.distributed.destroy_process_group()
407
+
408
+ if __name__ == "__main__":
409
+ main()
trainer/craft/train_distributed.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import argparse
3
+ import os
4
+ import shutil
5
+ import time
6
+ import multiprocessing as mp
7
+ import yaml
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ import wandb
14
+
15
+ from config.load_config import load_yaml, DotDict
16
+ from data.dataset import SynthTextDataSet, CustomDataset
17
+ from loss.mseloss import Maploss_v2, Maploss_v3
18
+ from model.craft import CRAFT
19
+ from eval import main_eval
20
+ from metrics.eval_det_iou import DetectionIoUEvaluator
21
+ from utils.util import copyStateDict, save_parser
22
+
23
+
24
+ class Trainer(object):
25
+ def __init__(self, config, gpu, mode):
26
+
27
+ self.config = config
28
+ self.gpu = gpu
29
+ self.mode = mode
30
+ self.net_param = self.get_load_param(gpu)
31
+
32
+ def get_synth_loader(self):
33
+
34
+ dataset = SynthTextDataSet(
35
+ output_size=self.config.train.data.output_size,
36
+ data_dir=self.config.train.synth_data_dir,
37
+ saved_gt_dir=None,
38
+ mean=self.config.train.data.mean,
39
+ variance=self.config.train.data.variance,
40
+ gauss_init_size=self.config.train.data.gauss_init_size,
41
+ gauss_sigma=self.config.train.data.gauss_sigma,
42
+ enlarge_region=self.config.train.data.enlarge_region,
43
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
44
+ aug=self.config.train.data.syn_aug,
45
+ vis_test_dir=self.config.vis_test_dir,
46
+ vis_opt=self.config.train.data.vis_opt,
47
+ sample=self.config.train.data.syn_sample,
48
+ )
49
+
50
+ syn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
51
+
52
+ syn_loader = torch.utils.data.DataLoader(
53
+ dataset,
54
+ batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
55
+ shuffle=False,
56
+ num_workers=self.config.train.num_workers,
57
+ sampler=syn_sampler,
58
+ drop_last=True,
59
+ pin_memory=True,
60
+ )
61
+ return syn_loader
62
+
63
+ def get_custom_dataset(self):
64
+
65
+ custom_dataset = CustomDataset(
66
+ output_size=self.config.train.data.output_size,
67
+ data_dir=self.config.data_root_dir,
68
+ saved_gt_dir=None,
69
+ mean=self.config.train.data.mean,
70
+ variance=self.config.train.data.variance,
71
+ gauss_init_size=self.config.train.data.gauss_init_size,
72
+ gauss_sigma=self.config.train.data.gauss_sigma,
73
+ enlarge_region=self.config.train.data.enlarge_region,
74
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
75
+ watershed_param=self.config.train.data.watershed,
76
+ aug=self.config.train.data.custom_aug,
77
+ vis_test_dir=self.config.vis_test_dir,
78
+ sample=self.config.train.data.custom_sample,
79
+ vis_opt=self.config.train.data.vis_opt,
80
+ pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
81
+ do_not_care_label=self.config.train.data.do_not_care_label,
82
+ )
83
+
84
+ return custom_dataset
85
+
86
+ def get_load_param(self, gpu):
87
+
88
+ if self.config.train.ckpt_path is not None:
89
+ map_location = "cuda:%d" % gpu
90
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
91
+ else:
92
+ param = None
93
+
94
+ return param
95
+
96
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
97
+ lr = lr * (gamma ** step)
98
+ for param_group in optimizer.param_groups:
99
+ param_group["lr"] = lr
100
+ return param_group["lr"]
101
+
102
+ def get_loss(self):
103
+ if self.config.train.loss == 2:
104
+ criterion = Maploss_v2()
105
+ elif self.config.train.loss == 3:
106
+ criterion = Maploss_v3()
107
+ else:
108
+ raise Exception("Undefined loss")
109
+ return criterion
110
+
111
+ def iou_eval(self, dataset, train_step, buffer, model):
112
+ test_config = DotDict(self.config.test[dataset])
113
+
114
+ val_result_dir = os.path.join(
115
+ self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
116
+ )
117
+
118
+ evaluator = DetectionIoUEvaluator()
119
+
120
+ metrics = main_eval(
121
+ None,
122
+ self.config.train.backbone,
123
+ test_config,
124
+ evaluator,
125
+ val_result_dir,
126
+ buffer,
127
+ model,
128
+ self.mode,
129
+ )
130
+ if self.gpu == 0 and self.config.wandb_opt:
131
+ wandb.log(
132
+ {
133
+ "{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
134
+ "{} iou Precision".format(dataset): np.round(
135
+ metrics["precision"], 3
136
+ ),
137
+ "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
138
+ }
139
+ )
140
+
141
+ def train(self, buffer_dict):
142
+
143
+ torch.cuda.set_device(self.gpu)
144
+ total_gpu_num = torch.cuda.device_count()
145
+
146
+ # MODEL -------------------------------------------------------------------------------------------------------#
147
+ # SUPERVISION model
148
+ if self.config.mode == "weak_supervision":
149
+ if self.config.train.backbone == "vgg":
150
+ supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
151
+ else:
152
+ raise Exception("Undefined architecture")
153
+
154
+ # NOTE: only work on half GPU assign train / half GPU assign supervision setting
155
+ supervision_device = total_gpu_num // 2 + self.gpu
156
+ if self.config.train.ckpt_path is not None:
157
+ supervision_param = self.get_load_param(supervision_device)
158
+ supervision_model.load_state_dict(
159
+ copyStateDict(supervision_param["craft"])
160
+ )
161
+ supervision_model = supervision_model.to(f"cuda:{supervision_device}")
162
+ print(f"Supervision model loading on : gpu {supervision_device}")
163
+ else:
164
+ supervision_model, supervision_device = None, None
165
+
166
+ # TRAIN model
167
+ if self.config.train.backbone == "vgg":
168
+ craft = CRAFT(pretrained=False, amp=self.config.train.amp)
169
+ else:
170
+ raise Exception("Undefined architecture")
171
+
172
+ if self.config.train.ckpt_path is not None:
173
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
174
+
175
+ craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
176
+ craft = craft.cuda()
177
+ craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])
178
+
179
+ torch.backends.cudnn.benchmark = True
180
+
181
+ # DATASET -----------------------------------------------------------------------------------------------------#
182
+
183
+ if self.config.train.use_synthtext:
184
+ trn_syn_loader = self.get_synth_loader()
185
+ batch_syn = iter(trn_syn_loader)
186
+
187
+ if self.config.train.real_dataset == "custom":
188
+ trn_real_dataset = self.get_custom_dataset()
189
+ else:
190
+ raise Exception("Undefined dataset")
191
+
192
+ if self.config.mode == "weak_supervision":
193
+ trn_real_dataset.update_model(supervision_model)
194
+ trn_real_dataset.update_device(supervision_device)
195
+
196
+ trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
197
+ trn_real_dataset
198
+ )
199
+ trn_real_loader = torch.utils.data.DataLoader(
200
+ trn_real_dataset,
201
+ batch_size=self.config.train.batch_size,
202
+ shuffle=False,
203
+ num_workers=self.config.train.num_workers,
204
+ sampler=trn_real_sampler,
205
+ drop_last=False,
206
+ pin_memory=True,
207
+ )
208
+
209
+ # OPTIMIZER ---------------------------------------------------------------------------------------------------#
210
+ optimizer = optim.Adam(
211
+ craft.parameters(),
212
+ lr=self.config.train.lr,
213
+ weight_decay=self.config.train.weight_decay,
214
+ )
215
+
216
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
217
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
218
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
219
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
220
+
221
+ # LOSS --------------------------------------------------------------------------------------------------------#
222
+ # mixed precision
223
+ if self.config.train.amp:
224
+ scaler = torch.cuda.amp.GradScaler()
225
+
226
+ if (
227
+ self.config.train.ckpt_path is not None
228
+ and self.config.train.st_iter != 0
229
+ ):
230
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
231
+ else:
232
+ scaler = None
233
+
234
+ criterion = self.get_loss()
235
+
236
+ # TRAIN -------------------------------------------------------------------------------------------------------#
237
+ train_step = self.config.train.st_iter
238
+ whole_training_step = self.config.train.end_iter
239
+ update_lr_rate_step = 0
240
+ training_lr = self.config.train.lr
241
+ loss_value = 0
242
+ batch_time = 0
243
+ start_time = time.time()
244
+
245
+ print(
246
+ "================================ Train start ================================"
247
+ )
248
+ while train_step < whole_training_step:
249
+ trn_real_sampler.set_epoch(train_step)
250
+ for (
251
+ index,
252
+ (
253
+ images,
254
+ region_scores,
255
+ affinity_scores,
256
+ confidence_masks,
257
+ ),
258
+ ) in enumerate(trn_real_loader):
259
+ craft.train()
260
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
261
+ update_lr_rate_step += 1
262
+ training_lr = self.adjust_learning_rate(
263
+ optimizer,
264
+ self.config.train.gamma,
265
+ update_lr_rate_step,
266
+ self.config.train.lr,
267
+ )
268
+
269
+ images = images.cuda(non_blocking=True)
270
+ region_scores = region_scores.cuda(non_blocking=True)
271
+ affinity_scores = affinity_scores.cuda(non_blocking=True)
272
+ confidence_masks = confidence_masks.cuda(non_blocking=True)
273
+
274
+ if self.config.train.use_synthtext:
275
+ # Synth image load
276
+ syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
277
+ batch_syn
278
+ )
279
+ syn_image = syn_image.cuda(non_blocking=True)
280
+ syn_region_label = syn_region_label.cuda(non_blocking=True)
281
+ syn_affi_label = syn_affi_label.cuda(non_blocking=True)
282
+ syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
283
+
284
+ # concat syn & custom image
285
+ images = torch.cat((syn_image, images), 0)
286
+ region_image_label = torch.cat(
287
+ (syn_region_label, region_scores), 0
288
+ )
289
+ affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
290
+ confidence_mask_label = torch.cat(
291
+ (syn_confidence_mask, confidence_masks), 0
292
+ )
293
+ else:
294
+ region_image_label = region_scores
295
+ affinity_image_label = affinity_scores
296
+ confidence_mask_label = confidence_masks
297
+
298
+ if self.config.train.amp:
299
+ with torch.cuda.amp.autocast():
300
+
301
+ output, _ = craft(images)
302
+ out1 = output[:, :, :, 0]
303
+ out2 = output[:, :, :, 1]
304
+
305
+ loss = criterion(
306
+ region_image_label,
307
+ affinity_image_label,
308
+ out1,
309
+ out2,
310
+ confidence_mask_label,
311
+ self.config.train.neg_rto,
312
+ self.config.train.n_min_neg,
313
+ )
314
+
315
+ optimizer.zero_grad()
316
+ scaler.scale(loss).backward()
317
+ scaler.step(optimizer)
318
+ scaler.update()
319
+
320
+ else:
321
+ output, _ = craft(images)
322
+ out1 = output[:, :, :, 0]
323
+ out2 = output[:, :, :, 1]
324
+ loss = criterion(
325
+ region_image_label,
326
+ affinity_image_label,
327
+ out1,
328
+ out2,
329
+ confidence_mask_label,
330
+ self.config.train.neg_rto,
331
+ )
332
+
333
+ optimizer.zero_grad()
334
+ loss.backward()
335
+ optimizer.step()
336
+
337
+ end_time = time.time()
338
+ loss_value += loss.item()
339
+ batch_time += end_time - start_time
340
+
341
+ if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
342
+ mean_loss = loss_value / 5
343
+ loss_value = 0
344
+ avg_batch_time = batch_time / 5
345
+ batch_time = 0
346
+
347
+ print(
348
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
349
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
350
+ time.strftime(
351
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
352
+ ),
353
+ train_step,
354
+ whole_training_step,
355
+ training_lr,
356
+ mean_loss,
357
+ avg_batch_time,
358
+ )
359
+ )
360
+
361
+ if self.gpu == 0 and self.config.wandb_opt:
362
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
363
+
364
+ if (
365
+ train_step % self.config.train.eval_interval == 0
366
+ and train_step != 0
367
+ ):
368
+
369
+ craft.eval()
370
+ # initialize all buffer value with zero
371
+ if self.gpu == 0:
372
+ for buffer in buffer_dict.values():
373
+ for i in range(len(buffer)):
374
+ buffer[i] = None
375
+
376
+ print("Saving state, index:", train_step)
377
+ save_param_dic = {
378
+ "iter": train_step,
379
+ "craft": craft.state_dict(),
380
+ "optimizer": optimizer.state_dict(),
381
+ }
382
+ save_param_path = (
383
+ self.config.results_dir
384
+ + "/CRAFT_clr_"
385
+ + repr(train_step)
386
+ + ".pth"
387
+ )
388
+
389
+ if self.config.train.amp:
390
+ save_param_dic["scaler"] = scaler.state_dict()
391
+ save_param_path = (
392
+ self.config.results_dir
393
+ + "/CRAFT_clr_amp_"
394
+ + repr(train_step)
395
+ + ".pth"
396
+ )
397
+
398
+ torch.save(save_param_dic, save_param_path)
399
+
400
+ # validation
401
+ self.iou_eval(
402
+ "custom_data",
403
+ train_step,
404
+ buffer_dict["custom_data"],
405
+ craft,
406
+ )
407
+
408
+ train_step += 1
409
+ if train_step >= whole_training_step:
410
+ break
411
+
412
+ if self.config.mode == "weak_supervision":
413
+ state_dict = craft.module.state_dict()
414
+ supervision_model.load_state_dict(state_dict)
415
+ trn_real_dataset.update_model(supervision_model)
416
+
417
+ # save last model
418
+ if self.gpu == 0:
419
+ save_param_dic = {
420
+ "iter": train_step,
421
+ "craft": craft.state_dict(),
422
+ "optimizer": optimizer.state_dict(),
423
+ }
424
+ save_param_path = (
425
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
426
+ )
427
+
428
+ if self.config.train.amp:
429
+ save_param_dic["scaler"] = scaler.state_dict()
430
+ save_param_path = (
431
+ self.config.results_dir
432
+ + "/CRAFT_clr_amp_"
433
+ + repr(train_step)
434
+ + ".pth"
435
+ )
436
+ torch.save(save_param_dic, save_param_path)
437
+
438
def main():
    """Parse CLI arguments, load the YAML config, and launch one DDP worker per GPU.

    Side effects: creates the results directory, copies the config file into it,
    and spawns ``main_worker`` processes via ``torch.multiprocessing``.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    # NOTE(review): "--use ddp port" looks like it was meant to be help text,
    # not a second option string (it cannot be typed on a shell without quoting).
    # Kept for CLI compatibility — confirm before changing.
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # Load configuration from config/<name>.yaml
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir (exist_ok avoids the check-then-create race)
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir for reproducibility
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    if config["mode"] == "weak_supervision":
        # NOTE: half of the GPUs train, the other half run the supervision model
        ngpus_per_node = torch.cuda.device_count() // 2
        mode = "weak_supervision"
    else:
        ngpus_per_node = torch.cuda.device_count()
        mode = None

    print(f"Total process num : {ngpus_per_node}")

    # Shared buffer the workers fill with per-image evaluation results
    manager = mp.Manager()
    buffer1 = manager.list([None] * config["test"]["custom_data"]["test_set_size"])

    buffer_dict = {"custom_data": buffer1}
    torch.multiprocessing.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(args.port, ngpus_per_node, config, buffer_dict, exp_name, mode),
    )
491
+
492
+
493
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name, mode):
    """Per-process entry point spawned by ``main``; trains on one GPU rank."""
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:" + port,
        world_size=ngpus_per_node,
        rank=gpu,
    )

    # Only rank 0 reports to wandb.
    if gpu == 0 and config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    # Split the global batch size evenly across the spawned processes.
    per_process_batch = int(config["train"]["batch_size"] / ngpus_per_node)
    config["train"]["batch_size"] = per_process_batch
    config = DotDict(config)

    # Start train
    trainer = Trainer(config, gpu, mode)
    trainer.train(buffer_dict)

    if gpu == 0 and config["wandb_opt"]:
        wandb.finish()

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
521
+
522
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
trainer/craft/utils/craft_utils.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ import torch
5
+ import cv2
6
+ import math
7
+ import numpy as np
8
+ from data import imgproc
9
+
10
+ """ auxilary functions """
11
+ # unwarp corodinates
12
+
13
+
14
+
15
+
16
def warpCoord(Minv, pt):
    """Map a 2-D point through the perspective transform ``Minv``.

    The point is lifted to homogeneous coordinates, transformed, and
    de-homogenized back to a 2-element array.
    """
    homogeneous = np.matmul(Minv, (pt[0], pt[1], 1))
    return np.array([homogeneous[0] / homogeneous[2], homogeneous[1] / homogeneous[2]])
19
+ """ end of auxilary functions """
20
+
21
def test():
    # Smoke-test helper left over from development; prints a fixed marker.
    print('pass')
23
+
24
+
25
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    """Extract rotated word boxes from CRAFT region/affinity score maps.

    :param textmap: 2-D float map of character-region scores (0~1).
    :param linkmap: 2-D float map of character-affinity scores (0~1), same shape.
    :param text_threshold: minimum peak region score a component must reach to be kept.
    :param link_threshold: binarization threshold applied to linkmap.
    :param low_text: binarization threshold applied to textmap.
    :return: (det, labels, mapper) — list of 4x2 boxes, the connected-component
        label image, and the component id backing each box.
    """
    # prepare data
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    """ labeling method """
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    # union of text and link pixels forms the candidate word regions
    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = \
        cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    # component 0 is background, so start at 1
    for k in range(1,nLabels):
        # size filtering: drop tiny components (noise)
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        # thresholding: keep only components whose peak region score is high enough
        if np.max(textmap[labels==k]) < text_threshold: continue

        # make segmentation map for this component only
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        # dilation radius derived from component density
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1)
        #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5))
        #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1)


        # make box: min-area rectangle around all segmented pixels
        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape: a near-square minAreaRect is replaced with the
        # axis-aligned bounding box for stability
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order, starting from the corner closest to the origin
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper
90
+
91
def getPoly_core(boxes, labels, mapper, linkmap):
    """Refine rectangular word boxes into curved polygons.

    Each box is rectified with a perspective warp, vertical pivot lines are
    sampled along the word, and a polygon is rebuilt from the pivots warped
    back into image space.  Entries that cannot be refined yield None.

    :param boxes: 4x2 boxes from getDetBoxes_core.
    :param labels: connected-component label image from getDetBoxes_core.
    :param mapper: component id backing each box.
    :param linkmap: affinity map (kept for interface parity; unused here).
    :return: list of the same length as ``boxes``, each a polygon array or None.
    """
    # configs
    num_cp = 5           # number of pivot points sampled along the word
    max_len_ratio = 0.7  # skip refinement when text height ~ box height
    expand_ratio = 1.45  # vertical expansion applied to the polygon
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        # size filter for small instance
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 30 or h < 30:
            polys.append(None); continue

        # warp image so the word is axis-aligned in a w x h canvas
        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except np.linalg.LinAlgError:
            # FIX: was a bare ``except:`` — only a singular transform is
            # expected here; anything else should propagate.
            polys.append(None); continue

        # binarization for selected label
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        """ Polygon generation """
        # find top/bottom contours per column
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length

        # pass if max_len is similar to h (text already fills the box)
        if h * max_len_ratio < max_len:
            polys.append(None); continue

        # get pivot points with fixed length
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg # segment width
        pp = [None] * num_cp # init pivot points
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0,len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0: continue # No polygon area

            # track the tallest column in each odd segment as the pivot
            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0: # gradient if zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps: grow outward until the
        # probe line no longer intersects the word mask
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None); continue

        # make final polygon: top edge left-to-right, then bottom edge back
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys
236
+
237
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    """Run box extraction and, optionally, polygon refinement.

    :return: (boxes, polys) — polys holds None placeholders when ``poly`` is False.
    """
    boxes, labels, mapper = getDetBoxes_core(
        textmap, linkmap, text_threshold, link_threshold, low_text
    )

    if poly:
        polys = getPoly_core(boxes, labels, mapper, linkmap)
    else:
        polys = [None for _ in boxes]

    return boxes, polys
246
+
247
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
    """Scale detector-space coordinates back to original-image space.

    Each entry of ``polys`` is an (N, 2) array of (x, y) points or None.
    x is multiplied by ``ratio_w * ratio_net`` and y by ``ratio_h * ratio_net``.

    :return: numpy array of the scaled entries (object dtype when entries are
        ragged or None); an empty input is returned unchanged.
    """
    if len(polys) > 0:
        try:
            polys = np.array(polys)
        except ValueError:
            # FIX: ragged polygon lists (mixed point counts or None entries)
            # cannot be packed into a regular ndarray on NumPy >= 1.24; fall
            # back to an explicit object array.
            ragged = np.empty(len(polys), dtype=object)
            for k, p in enumerate(polys):
                ragged[k] = p
            polys = ragged
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
    return polys
254
+
255
def save_outputs(image, region_scores, affinity_scores, text_threshold, link_threshold,
                 low_text, outoput_path, confidence_mask = None):
    """Save image, region_scores, and affinity_scores in a single image.

    region_scores and affinity_scores must be cpu numpy arrays. You can convert
    GPU Tensors to CPU numpy arrays like this:
    >>> array = tensor.cpu().data.numpy()
    When saving outputs of the network during training, make sure you convert ALL
    tensors (image, region_score, affinity_score) to numpy array first.
    :param image: numpy array
    :param region_scores: [] 2D numpy array with each element between 0~1.
    :param affinity_scores: same as region_scores
    :param text_threshold: 0 ~ 1. Closer to 0, characters with lower confidence will also be considered a word and be boxed
    :param link_threshold: 0 ~ 1. Closer to 0, links with lower confidence will also be considered a word and be boxed
    :param low_text: 0 ~ 1. Closer to 0, boxes will be more loosely drawn.
    :param outoput_path: file path the composite image is written to
        (parameter name keeps the original "outoput" typo — positional callers
        and keyword callers both depend on it).
    :param confidence_mask: optional 2D map appended to the visualization.
    :return: the composite image that was written.
    """

    assert region_scores.shape == affinity_scores.shape
    assert len(image.shape) - 1 == len(region_scores.shape)

    boxes, polys = getDetBoxes(region_scores, affinity_scores, text_threshold, link_threshold,
                               low_text, False)
    # boxes come from the score maps; scaled up by 2 before drawing on the image
    boxes = np.array(boxes, np.int32) * 2
    if len(boxes) > 0:
        # NOTE(review): np.clip without out= returns a new array — these two
        # calls currently have no effect on ``boxes``; confirm intent.
        np.clip(boxes[:, :, 0], 0, image.shape[1])
        np.clip(boxes[:, :, 1], 0, image.shape[0])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)

    if confidence_mask is not None:
        # layout: image | [region+affinity heatmaps ; blank+confidence mask]
        confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
        gt_scores = np.hstack([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color])
        confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray), confidence_mask_gray])
        output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
        output = np.hstack([image, output])

    else:
        # layout: image | [region heatmap ; affinity heatmap]
        gt_scores = np.concatenate([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color], axis=0)
        output = np.hstack([image, gt_scores])

    cv2.imwrite(outoput_path, output)
    return output
301
+
302
+
303
def save_outputs_from_tensors(images, region_scores, affinity_scores, text_threshold, link_threshold,
                              low_text, output_dir, image_names, confidence_mask = None):

    """Take images, region_scores, and affinity_scores as tensors (can be GPU)
    and save one visualization per batch element via ``save_outputs``.
    :param images: 4D tensor
    :param region_scores: 3D tensor with values between 0 ~ 1
    :param affinity_scores: 3D tensor with values between 0 ~ 1
    :param text_threshold:
    :param link_threshold:
    :param low_text:
    :param output_dir: directory to save the output images. Will be joined with base names of image_names
    :param image_names: names of each image. Doesn't have to be the base name (image file names)
    :param confidence_mask:
    :return: list of the composite images written, one per batch element.
    """
    #import ipdb;ipdb.set_trace()
    #images = images.cpu().permute(0, 2, 3, 1).contiguous().data.numpy()
    # NOTE(review): np.array on a CUDA tensor raises; the commented-out line
    # above suggests callers are expected to pass CPU, HWC-ordered images —
    # confirm against the call sites.
    if type(images) == torch.Tensor:
        images = np.array(images)

    region_scores = region_scores.cpu().data.numpy()
    affinity_scores = affinity_scores.cpu().data.numpy()

    batch_size = images.shape[0]
    assert batch_size == region_scores.shape[0] and batch_size == affinity_scores.shape[0] and batch_size == len(image_names), \
        "The first dimension (i.e. batch size) of images, region scores, and affinity scores must be equal"

    output_images = []

    for i in range(batch_size):
        image = images[i]
        region_score = region_scores[i]
        affinity_score = affinity_scores[i]

        # destination file keeps only the base name of the (possibly full) path
        image_name = os.path.basename(image_names[i])
        outoput_path = os.path.join(output_dir,image_name)

        output_image = save_outputs(image, region_score, affinity_score, text_threshold, link_threshold,
                                    low_text, outoput_path, confidence_mask=confidence_mask)

        output_images.append(output_image)

    return output_images
trainer/craft/utils/inference_boxes.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import itertools
4
+
5
+ import cv2
6
+ import time
7
+ import numpy as np
8
+ import torch
9
+ from torch.autograd import Variable
10
+
11
+ from utils.craft_utils import getDetBoxes, adjustResultCoordinates
12
+ from data import imgproc
13
+ from data.dataset import SynthTextDataSet
14
+ import math
15
+ import xml.etree.ElementTree as elemTree
16
+
17
+
18
+ #-------------------------------------------------------------------------------------------------------------------#
19
def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate point (xp, yp) about center (xc, yc) by ``theta`` radians.

    Returns integer pixel coordinates (truncated toward zero).
    """
    dx, dy = xp - xc, yp - yc

    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    rotated_x = cos_t * dx + sin_t * dy
    rotated_y = cos_t * dy - sin_t * dx
    return int(xc + rotated_x), int(yc + rotated_y)
29
+
30
def addRotatedShape(cx, cy, w, h, angle):
    """Return the four corners of a ``w`` x ``h`` box centered at (cx, cy),
    rotated by ``angle``; corners are listed clockwise from top-left."""
    half_w, half_h = w / 2, h / 2
    corners = [
        (cx - half_w, cy - half_h),
        (cx + half_w, cy - half_h),
        (cx + half_w, cy + half_h),
        (cx - half_w, cy + half_h),
    ]

    points = [list(rotatePoint(cx, cy, px, py, -angle)) for px, py in corners]

    return points
39
+
40
def xml_parsing(xml):
    """Parse a labelImg-style XML annotation file into box dicts.

    Supports rotated boxes (<robndbox>: cx/cy/w/h/angle) and axis-aligned
    boxes (<bndbox>: xmin/ymin/xmax/ymax).  Objects named "dnc" (do-not-care)
    are marked ignored with text "###".

    :param xml: path to the XML file.
    :return: list of {"points": (4, 2) ndarray, "text": str, "ignore": bool}.
    """
    tree = elemTree.parse(xml)

    annotations = []  # one entry per parsed box
    for element in tree.iter(tag="object"):
        name = element.find("name").text

        # FIX: the original reused a single dict across all boxes of an
        # object, so every appended entry aliased the same object and later
        # boxes overwrote earlier ones.  Build a fresh dict per box instead.
        for box_coord in element.iter(tag="robndbox"):
            cx = float(box_coord.find("cx").text)
            cy = float(box_coord.find("cy").text)
            w = float(box_coord.find("w").text)
            h = float(box_coord.find("h").text)
            angle = float(box_coord.find("angle").text)

            annotations.append(
                {"name": name, "box_coodi": addRotatedShape(cx, cy, w, h, angle)}
            )

        for box_coord in element.iter(tag="bndbox"):
            xmin = int(box_coord.find("xmin").text)
            ymin = int(box_coord.find("ymin").text)
            xmax = int(box_coord.find("xmax").text)
            ymax = int(box_coord.find("ymax").text)

            annotations.append(
                {
                    "name": name,
                    "box_coodi": [[xmin, ymin], [xmax, ymin], [xmax, ymax],
                                  [xmin, ymax]],
                }
            )

    bounds = []
    for ann in annotations:
        box_info_dict = {"points": None, "text": None, "ignore": None}

        box_info_dict["points"] = np.array(ann["box_coodi"])
        if ann["name"] == "dnc":
            box_info_dict["text"] = "###"
            box_info_dict["ignore"] = True
        else:
            box_info_dict["text"] = ann["name"]
            box_info_dict["ignore"] = False

        bounds.append(box_info_dict)

    return bounds
98
+
99
+ #-------------------------------------------------------------------------------------------------------------------#
100
+
101
def load_prescription_gt(dataFolder):
    """Recursively collect (.jpg, .xml) pairs under ``dataFolder`` and parse labels.

    :return: (list of per-image annotation lists, sorted list of image paths)
    """
    total_img_path = []
    total_imgs_bboxes = []
    for root, _directories, files in os.walk(dataFolder):
        for file in files:
            full_path = os.path.join(root, file)
            if '.jpg' in file:
                total_img_path.append(full_path)
            if '.xml' in file:
                total_imgs_bboxes.append(full_path)

    total_imgs_parsing_bboxes = []
    for img_path, bbox in zip(sorted(total_img_path), sorted(total_imgs_bboxes)):
        # image and annotation file must share a basename
        assert img_path.split(".jpg")[0] == bbox.split(".xml")[0]
        total_imgs_parsing_bboxes.append(xml_parsing(bbox))

    return total_imgs_parsing_bboxes, sorted(total_img_path)
127
+
128
+
129
+ # NOTE
130
def load_prescription_cleval_gt(dataFolder):
    """Collect (.jpg, *_cl.txt) pairs under ``dataFolder`` and parse CLEval labels.

    Each label line is "x1,y1,x2,y2,x3,y3,x4,y4[,text...]"; only the first
    eight coordinates are kept ("text"/"ignore" stay None).

    :return: (list of per-image box-dict lists, sorted list of image paths)
    """
    total_img_path = []
    total_gt_path = []
    for root, _directories, files in os.walk(dataFolder):
        for file in files:
            full_path = os.path.join(root, file)
            if '.jpg' in file:
                total_img_path.append(full_path)
            if '_cl.txt' in file:
                total_gt_path.append(full_path)

    total_imgs_parsing_bboxes = []
    for img_path, gt_path in zip(sorted(total_img_path), sorted(total_gt_path)):
        # image and label file must share a basename
        assert img_path.split(".jpg")[0] == gt_path.split('_label_cl.txt')[0]

        word_bboxes = []
        for line in open(gt_path, encoding="utf-8").readlines():
            fields = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box_info_dict = {"points": None, "text": None, "ignore": None}
            box_info_dict["points"] = np.array([int(fields[i]) for i in range(8)])
            word_bboxes.append(box_info_dict)

        total_imgs_parsing_bboxes.append(word_bboxes)

    return total_imgs_parsing_bboxes, sorted(total_img_path)
165
+
166
+
167
def load_synthtext_gt(data_folder):
    """Load word-level ground truth for the first 100 SynthText images.

    :param data_folder: SynthText root passed to SynthTextDataSet.
    :return: (list of per-image box-dict lists, list of image paths)
    :raises ValueError: if the word count does not match the box count.
    """
    synth_dataset = SynthTextDataSet(
        output_size=768, data_dir=data_folder, saved_gt_dir=data_folder, logging=False
    )
    img_names, img_bbox, img_words = synth_dataset.load_data(bbox="word")

    total_img_path = []
    total_imgs_bboxes = []
    for index in range(len(img_bbox[:100])):
        img_path = os.path.join(data_folder, img_names[index][0])
        total_img_path.append(img_path)
        # A single-word image collapses the last axis, so transposing to
        # (n, 4, 2) fails; re-expand and transpose the 2-D case instead.
        try:
            wordbox = img_bbox[index].transpose((2, 1, 0))
        except ValueError:
            wordbox = np.expand_dims(img_bbox[index], axis=0)
            wordbox = wordbox.transpose((0, 2, 1))

        # split transcriptions on whitespace/newlines and drop empties
        words = [re.split(" \n|\n |\n| ", t.strip()) for t in img_words[index]]
        words = list(itertools.chain(*words))
        words = [t for t in words if len(t) > 0]

        if len(words) != len(wordbox):
            # FIX: was ``import ipdb; ipdb.set_trace()`` — dropping into an
            # interactive debugger hangs non-interactive jobs; fail loudly.
            raise ValueError(
                "word/box count mismatch for {}: {} words vs {} boxes".format(
                    img_path, len(words), len(wordbox)
                )
            )

        single_img_bboxes = []
        for j in range(len(words)):
            box_info_dict = {"points": None, "text": None, "ignore": None}
            box_info_dict["points"] = wordbox[j]
            box_info_dict["text"] = words[j]
            box_info_dict["ignore"] = False
            single_img_bboxes.append(box_info_dict)

        total_imgs_bboxes.append(single_img_bboxes)

    return total_imgs_bboxes, total_img_path
205
+
206
+
207
def load_icdar2015_gt(dataFolder, isTraing=False):
    """Load ICDAR2015 ground truth (8-point quadrilaterals + transcription).

    :param dataFolder: dataset root containing image and GT folders.
    :param isTraing: if True, read the training split, else the test split.
    :return: (list of per-image box-dict lists, list of image paths)
    """
    if isTraing:
        img_folderName = "ch4_training_images"
        gt_folderName = "ch4_training_localization_transcription_gt"
    else:
        img_folderName = "ch4_test_images"
        gt_folderName = "ch4_test_localization_transcription_gt"

    gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName))
    total_imgs_bboxes = []
    total_img_path = []
    for gt_path in gt_folder_path:
        gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path)
        # image path mirrors the GT path: gt_img_1.txt -> img_1.jpg
        img_path = (
            gt_path.replace(gt_folderName, img_folderName)
            .replace(".txt", ".jpg")
            .replace("gt_", "")
        )
        # NOTE(review): cv2.imread returns None for a missing file, which
        # would crash in cv2.polylines below — confirm images always exist.
        image = cv2.imread(img_path)
        lines = open(gt_path, encoding="utf-8").readlines()
        single_img_bboxes = []
        for line in lines:
            box_info_dict = {"points": None, "text": None, "ignore": None}

            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box_points = [int(box_info[j]) for j in range(8)]
            # transcription may itself contain commas, so re-join the tail
            word = box_info[8:]
            word = ",".join(word)
            box_points = np.array(box_points, np.int32).reshape(4, 2)
            # FIX: ``np.int`` was removed in NumPy 1.24; use np.int32.
            cv2.polylines(
                image, [np.array(box_points).astype(np.int32)], True, (0, 0, 255), 1
            )
            box_info_dict["points"] = box_points
            box_info_dict["text"] = word
            if word == "###":
                box_info_dict["ignore"] = True
            else:
                box_info_dict["ignore"] = False

            single_img_bboxes.append(box_info_dict)
        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)
    return total_imgs_bboxes, total_img_path
250
+
251
+
252
def load_icdar2013_gt(dataFolder, isTraing=False):
    """Load ICDAR2013 ground truth (axis-aligned boxes + transcription).

    Note: both branches currently point at the test split folders.

    :return: (list of per-image box-dict lists, list of image paths)
    """
    # choose test dataset
    if isTraing:
        img_folderName = "Challenge2_Test_Task12_Images"
        gt_folderName = "Challenge2_Test_Task1_GT"
    else:
        img_folderName = "Challenge2_Test_Task12_Images"
        gt_folderName = "Challenge2_Test_Task1_GT"

    gt_dir = os.path.join(dataFolder, gt_folderName)

    total_imgs_bboxes = []
    total_img_path = []
    for gt_name in os.listdir(gt_dir):
        gt_path = os.path.join(gt_dir, gt_name)
        # image path mirrors the GT path: gt_123.txt -> 123.jpg
        img_path = (
            gt_path.replace(gt_folderName, img_folderName)
            .replace(".txt", ".jpg")
            .replace("gt_", "")
        )
        image = cv2.imread(img_path)
        single_img_bboxes = []
        for line in open(gt_path, encoding="utf-8").readlines():
            fields = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            corner = [int(fields[j]) for j in range(4)]
            # transcription may contain commas, so re-join the tail
            word = ",".join(fields[4:])
            quad = [
                [corner[0], corner[1]],
                [corner[2], corner[1]],
                [corner[2], corner[3]],
                [corner[0], corner[3]],
            ]

            box_info_dict = {
                "points": quad,
                "text": word,
                "ignore": word == "###",
            }
            single_img_bboxes.append(box_info_dict)

        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)

    return total_imgs_bboxes, total_img_path
303
+
304
+
305
def test_net(
    net,
    image,
    text_threshold,
    link_threshold,
    low_text,
    cuda,
    poly,
    canvas_size=1280,
    mag_ratio=1.5,
):
    """Run one CRAFT forward pass on an image and post-process detections.

    :param net: CRAFT network; callable returning (score_map, feature).
    :param image: HxWxC numpy image.
    :param text_threshold: region-score threshold for box extraction.
    :param link_threshold: affinity-score threshold for box extraction.
    :param low_text: low-bound text binarization threshold.
    :param cuda: move the input to GPU when True.
    :param poly: also compute refined polygons when True.
    :param canvas_size: maximum resize dimension.
    :param mag_ratio: magnification applied during resize.
    :return: (boxes, polys, [region heatmap, link heatmap]) in original-image coords.
    """
    # resize while keeping the aspect ratio
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
    )
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
    # FIX: torch.autograd.Variable is a deprecated no-op wrapper; a plain
    # tensor with a batch dimension is equivalent.
    x = x.unsqueeze(0)  # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0, :, :, 0].cpu().data.numpy().astype(np.float32)
    score_link = y[0, :, :, 1].cpu().data.numpy().astype(np.float32)

    # crop any padding added by resize_aspect_ratio
    score_text = score_text[: size_heatmap[0], : size_heatmap[1]]
    score_link = score_link[: size_heatmap[0], : size_heatmap[1]]

    # Post-processing
    boxes, polys = getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )

    # coordinate adjustment back to the original image scale
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        # fall back to the rectangular box when polygon refinement failed
        if polys[k] is None:
            polys[k] = boxes[k]

    # render results (optional)
    score_text = score_text.copy()
    render_score_text = imgproc.cvt2HeatmapImg(score_text)
    render_score_link = imgproc.cvt2HeatmapImg(score_link)
    render_img = [render_score_text, render_score_link]
    # ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, render_img
trainer/craft/utils/util.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from data import imgproc
8
+ from utils import craft_utils
9
+
10
+
11
def copyStateDict(state_dict):
    """Return a copy of *state_dict* with any leading ``module`` key prefix
    (added by ``nn.DataParallel``) stripped from every key."""
    first_key = list(state_dict.keys())[0]
    drop = 1 if first_key.startswith("module") else 0
    renamed = OrderedDict()
    for key, value in state_dict.items():
        renamed[".".join(key.split(".")[drop:])] = value
    return renamed
21
+
22
+
23
def saveInput(
    imagename, vis_dir, image, region_scores, affinity_scores, confidence_mask
):
    """Save a training-input visualization to ``<vis_dir>/<imagename>_input.jpg``.

    The output concatenates, left to right: the image with detected boxes,
    the region/affinity heatmap overlays, and the confidence mask.
    """
    image = np.uint8(image.copy())
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # cv2.imwrite expects BGR

    # Decode boxes from the score maps with fixed thresholds
    # (text=0.85, link=0.2, low_text=0.5, no polygons).
    boxes, polys = craft_utils.getDetBoxes(
        region_scores, affinity_scores, 0.85, 0.2, 0.5, False
    )

    # Score maps are typically half the image resolution; rescale boxes if so.
    if image.shape[0] / region_scores.shape[0] >= 2:
        boxes = np.array(boxes, np.int32) * 2
    else:
        boxes = np.array(boxes, np.int32)

    if len(boxes) > 0:
        # NOTE(review): np.clip is not in-place and its result is discarded —
        # presumably meant to clamp boxes into the image; kept as-is.
        np.clip(boxes[:, :, 0], 0, image.shape[1])
        np.clip(boxes[:, :, 1], 0, image.shape[0])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)

    # overlay: resize heatmaps to image size and alpha-blend onto the image
    height, width, channel = image.shape
    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
    confidence_mask_gray = cv2.resize(
        confidence_mask_gray, (width, height), interpolation=cv2.INTER_NEAREST
    )
    overlay_region = cv2.addWeighted(image, 0.4, overlay_region, 0.6, 5)
    overlay_aff = cv2.addWeighted(image, 0.4, overlay_aff, 0.7, 6)

    gt_scores = np.concatenate([overlay_region, overlay_aff], axis=1)

    output = np.concatenate([gt_scores, confidence_mask_gray], axis=1)

    output = np.hstack([image, output])

    # synthtext: SynthText passes a sequence of paths; reduce to a bare stem
    if type(imagename) is not str:
        imagename = imagename[0].split("/")[-1][:-4]

    outpath = vis_dir + f"/{imagename}_input.jpg"
    if not os.path.exists(os.path.dirname(outpath)):
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
    cv2.imwrite(outpath, output)
    # print(f'Logging train input into {outpath}')
72
+
73
+
74
def saveImage(
    imagename,
    vis_dir,
    image,
    bboxes,
    affi_bboxes,
    region_scores,
    affinity_scores,
    confidence_mask,
):
    """Save a ground-truth visualization to ``<vis_dir>/<imagename>.jpg``.

    The output concatenates, left to right: the image with character boxes
    (red) and affinity boxes (blue), the region/affinity heatmap overlays,
    and the confidence mask.
    """
    output_image = np.uint8(image.copy())
    output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)  # BGR for imwrite
    if len(bboxes) > 0:
        # bboxes is a list of per-word arrays of character boxes — draw each in red
        for i in range(len(bboxes)):
            _bboxes = np.int32(bboxes[i])
            for j in range(_bboxes.shape[0]):
                cv2.polylines(
                    output_image,
                    [np.reshape(_bboxes[j], (-1, 1, 2))],
                    True,
                    (0, 0, 255),
                )

        # affinity boxes are drawn in blue
        for i in range(len(affi_bboxes)):
            cv2.polylines(
                output_image,
                [np.reshape(affi_bboxes[i].astype(np.int32), (-1, 1, 2))],
                True,
                (255, 0, 0),
            )

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)

    # overlay: resize heatmaps to image size and alpha-blend onto the image
    height, width, channel = image.shape
    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))

    overlay_region = cv2.addWeighted(image.copy(), 0.4, overlay_region, 0.6, 5)
    overlay_aff = cv2.addWeighted(image.copy(), 0.4, overlay_aff, 0.6, 5)

    heat_map = np.concatenate([overlay_region, overlay_aff], axis=1)

    # synthtext: SynthText passes a sequence of paths; reduce to a bare stem
    if type(imagename) is not str:
        imagename = imagename[0].split("/")[-1][:-4]

    output = np.concatenate([output_image, heat_map, confidence_mask_gray], axis=1)
    outpath = vis_dir + f"/{imagename}.jpg"
    if not os.path.exists(os.path.dirname(outpath)):
        os.makedirs(os.path.dirname(outpath), exist_ok=True)

    cv2.imwrite(outpath, output)
    # print(f'Logging original image into {outpath}')
130
+
131
+
132
def save_parser(args):
    """Append every parsed option (one ``key: value`` per line) to
    ``<results_dir>/opt.txt`` and echo the same report to stdout."""
    with open(f"{args.results_dir}/opt.txt", "a", encoding="utf-8") as opt_file:
        lines = ["------------ Options -------------\n"]
        lines.extend(f"{key}: {value}\n" for key, value in vars(args).items())
        lines.append("---------------------------------------\n")
        report = "".join(lines)
        print(report)
        opt_file.write(report)
trainer/dataset.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import six
5
+ import math
6
+ import torch
7
+ import pandas as pd
8
+
9
+ from natsort import natsorted
10
+ from PIL import Image
11
+ import numpy as np
12
+ from torch.utils.data import Dataset, ConcatDataset, Subset
13
+ from torch._utils import _accumulate
14
+ import torchvision.transforms as transforms
15
+
16
def contrast_grey(img):
    """Return ``(contrast, high, low)`` for a grey image, where high/low are
    the 90th/10th intensity percentiles and contrast is their Michelson
    ratio ``(high - low) / (high + low)``."""
    p_high = np.percentile(img, 90)
    p_low = np.percentile(img, 10)
    return (p_high - p_low) / (p_high + p_low), p_high, p_low
20
+
21
def adjust_contrast_grey(img, target = 0.4):
    """If the image's Michelson contrast is below *target*, stretch its
    intensities (returning uint8); otherwise return the image unchanged."""
    contrast, high, low = contrast_grey(img)
    if contrast >= target:
        return img
    ratio = 200. / (high - low)
    stretched = (img.astype(int) - low + 25) * ratio
    # clamp into [0, 255] before the uint8 cast
    clamped = np.minimum(np.full(stretched.shape, 255), stretched)
    clamped = np.maximum(np.full(stretched.shape, 0), clamped)
    return clamped.astype(np.uint8)
29
+
30
+
31
class Batch_Balanced_Dataset(object):
    """Draws each training batch from several datasets at fixed proportions
    given by ``opt.select_data`` / ``opt.batch_ratio``."""

    def __init__(self, opt):
        """
        Modulate the data ratio in the batch.
        For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
        the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
        """
        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
        dashed_line = '-' * 80
        print(dashed_line)
        log.write(dashed_line + '\n')
        print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
        assert len(opt.select_data) == len(opt.batch_ratio)

        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust = opt.contrast_adjust)
        self.data_loader_list = []
        self.dataloader_iter_list = []
        batch_size_list = []
        Total_batch_size = 0
        for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
            # each selected source contributes at least one sample per batch
            _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
            print(dashed_line)
            log.write(dashed_line + '\n')
            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
            total_number_dataset = len(_dataset)
            log.write(_dataset_log)

            """
            The total number of data can be modified with opt.total_data_usage_ratio.
            ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
            See 4.2 section in our paper.
            """
            number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
            dataset_split = [number_dataset, total_number_dataset - number_dataset]
            indices = range(total_number_dataset)
            # keep only the first split (the used portion); discard the rest
            _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                           for offset, length in zip(_accumulate(dataset_split), dataset_split)]
            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
            print(selected_d_log)
            log.write(selected_d_log + '\n')
            batch_size_list.append(str(_batch_size))
            Total_batch_size += _batch_size

            _data_loader = torch.utils.data.DataLoader(
                _dataset, batch_size=_batch_size,
                shuffle=True,
                num_workers=int(opt.workers), #prefetch_factor=2,persistent_workers=True,
                collate_fn=_AlignCollate, pin_memory=True)
            self.data_loader_list.append(_data_loader)
            self.dataloader_iter_list.append(iter(_data_loader))

        Total_batch_size_log = f'{dashed_line}\n'
        batch_size_sum = '+'.join(batch_size_list)
        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
        Total_batch_size_log += f'{dashed_line}'
        # the effective batch size is the sum of all per-source batch sizes
        opt.batch_size = Total_batch_size

        print(Total_batch_size_log)
        log.write(Total_batch_size_log + '\n')
        log.close()

    def get_batch(self):
        """Assemble one balanced batch; returns (images tensor, labels list).
        Loaders that run out of data are restarted transparently."""
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                # Bug fix: use the Python 3 iterator protocol `next(it)`.
                # The previous `it.next()` is Python 2 style and raises
                # AttributeError with modern DataLoader iterators.
                image, text = next(data_loader_iter)
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except StopIteration:
                # source exhausted: restart its loader and take a fresh batch
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(self.dataloader_iter_list[i])
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except ValueError:
                # NOTE(review): a ValueError from the loader silently skips
                # this source for the batch — kept as-is.
                pass

        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts
115
+
116
+
117
def hierarchical_dataset(root, opt, select_data='/'):
    """ select_data='/' contains all sub-directory of root directory """
    collected = []
    dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
    print(dataset_log)
    dataset_log += '\n'
    for dirpath, dirnames, filenames in os.walk(root + '/'):
        if dirnames:
            continue  # only leaf directories hold sample sets
        if not any(selected_d in dirpath for selected_d in select_data):
            continue  # leaf not matching any requested source
        dataset = OCRDataset(dirpath, opt)
        sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
        print(sub_dataset_log)
        dataset_log += f'{sub_dataset_log}\n'
        collected.append(dataset)

    return ConcatDataset(collected), dataset_log
141
+
142
class OCRDataset(Dataset):
    """CSV-backed recognition dataset: reads ``labels.csv`` (columns
    ``filename`` and ``words``) from *root* and serves (PIL image, label)
    pairs, optionally filtered by length and character set."""

    def __init__(self, root, opt):

        self.root = root
        self.opt = opt
        print(root)
        # The regex separator splits each line at its FIRST comma only, so
        # label text may itself contain commas.
        self.df = pd.read_csv(os.path.join(root,'labels.csv'), sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
        self.nSamples = len(self.df)

        if self.opt.data_filtering_off:
            # Bug fix: the dataframe index is 0-based, so the unfiltered index
            # list must be 0-based too. The previous `index + 1` (a leftover
            # from a 1-indexed LMDB reader) skipped the first sample and
            # raised KeyError for the last one in __getitem__.
            self.filtered_index_list = list(range(self.nSamples))
        else:
            self.filtered_index_list = []
            for index in range(self.nSamples):
                label = self.df.at[index,'words']
                try:
                    if len(label) > self.opt.batch_max_length:
                        continue
                except Exception:
                    # label without len() — should not occur with
                    # keep_default_na=False; log it and fall through
                    print(label)
                out_of_char = f'[^{self.opt.character}]'
                if re.search(out_of_char, label.lower()):
                    continue  # label contains characters outside the charset
                self.filtered_index_list.append(index)
            self.nSamples = len(self.filtered_index_list)

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # map dataset position -> dataframe row index
        index = self.filtered_index_list[index]
        img_fname = self.df.at[index,'filename']
        img_fpath = os.path.join(self.root, img_fname)
        label = self.df.at[index,'words']

        if self.opt.rgb:
            img = Image.open(img_fpath).convert('RGB')  # for color image
        else:
            img = Image.open(img_fpath).convert('L')

        if not self.opt.sensitive:
            label = label.lower()

        # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
        out_of_char = f'[^{self.opt.character}]'
        label = re.sub(out_of_char, '', label)

        return (img, label)
191
+
192
class ResizeNormalize(object):
    """Warp a PIL image to a fixed ``(width, height)`` and normalize it to
    a tensor in [-1, 1]."""

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        # [0, 1] -> [-1, 1], in place
        tensor.sub_(0.5).div_(0.5)
        return tensor
204
+
205
+
206
class NormalizePAD(object):
    """Normalize an image to [-1, 1] and right-pad it to a fixed width,
    replicating the last pixel column into the padded region."""

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)  # [0, 1] -> [-1, 1]
        c, h, w = tensor.size()
        padded = torch.FloatTensor(*self.max_size).fill_(0)
        padded[:, :, :w] = tensor  # content flush left
        if self.max_size[2] != w:
            # replicate the rightmost column across the padded area
            padded[:, :, w:] = tensor[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
        return padded
224
+
225
+
226
class AlignCollate(object):
    """Collate function turning a list of (PIL image, label) samples into a
    normalized image batch tensor plus the label tuple.

    With ``keep_ratio_with_pad=True`` each image is resized preserving its
    aspect ratio and right-padded to ``imgW`` (as in the 'Rosetta' paper);
    otherwise every image is warped to exactly ``(imgW, imgH)``.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, contrast_adjust = 0.):
        # imgH/imgW: target batch image size
        # contrast_adjust: if > 0, minimum grey contrast enforced per image
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.contrast_adjust = contrast_adjust

    def __call__(self, batch):
        # drop samples a Dataset may have returned as None
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            resized_max_w = self.imgW
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

            resized_images = []
            for image in images:
                w, h = image.size

                #### augmentation here - change contrast
                if self.contrast_adjust > 0:
                    image = np.array(image.convert("L"))
                    image = adjust_contrast_grey(image, target = self.contrast_adjust)
                    image = Image.fromarray(image, 'L')

                # scale to target height; clamp width at imgW
                ratio = w / float(h)
                if math.ceil(self.imgH * ratio) > self.imgW:
                    resized_w = self.imgW
                else:
                    resized_w = math.ceil(self.imgH * ratio)

                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized_image))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = [transform(image) for image in images]
            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

        return image_tensors, labels
271
+
272
+
273
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert a CHW tensor in [-1, 1] to an HWC numpy image in [0, 255]."""
    arr = image_tensor.cpu().float().numpy()
    if arr.shape[0] == 1:
        # single-channel -> fake RGB by replicating the channel
        arr = np.tile(arr, (3, 1, 1))
    hwc = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return hwc.astype(imtype)
279
+
280
+
281
def save_image(image_numpy, image_path):
    """Write an HWC uint8 numpy image to *image_path* via PIL."""
    Image.fromarray(image_numpy).save(image_path)
trainer/model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from modules.transformation import TPS_SpatialTransformerNetwork
3
+ from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
4
+ from modules.sequence_modeling import BidirectionalLSTM
5
+ from modules.prediction import Attention
6
+
7
class Model(nn.Module):
    """Four-stage scene-text recognition network: Transformation (TPS) ->
    FeatureExtraction (VGG/RCNN/ResNet) -> SequenceModeling (BiLSTM) ->
    Prediction (CTC/Attn). Each stage is selected via *opt*; unselected
    optional stages are skipped at forward time based on ``self.stages``."""

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        # record the configured stage names for forward-time dispatch
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC or Attn')

    def forward(self, input, text, is_train=True):
        """Run the configured pipeline.

        Args:
            input: image batch tensor.
            text: target text indices (used by the Attn predictor).
            is_train: forwarded to the Attn predictor (teacher forcing).
        Returns:
            per-step class scores from the prediction head.
        """
        """ Transformation stage """
        if not self.stages['Trans'] == "None":
            input = self.Transformation(input)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)  # drop the pooled height axis -> [b, w, c]

        """ Sequence modeling stage """
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            contextual_feature = visual_feature  # for convenience. this is NOT contextually modeled by BiLSTM

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)

        return prediction
trainer/modules/feature_extraction.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
class VGG_FeatureExtractor(nn.Module):
    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        # per-stage channel widths, e.g. [64, 128, 256, 512] for the default
        oc = [int(output_channel / 8), int(output_channel / 4),
              int(output_channel / 2), output_channel]
        self.output_channel = oc
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, oc[0], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                                    # halve H and W
            nn.Conv2d(oc[0], oc[1], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                                    # halve H and W
            nn.Conv2d(oc[1], oc[2], 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(oc[2], oc[2], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),                          # halve H only
            nn.Conv2d(oc[2], oc[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(oc[3]), nn.ReLU(True),
            nn.Conv2d(oc[3], oc[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(oc[3]), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),                          # halve H only
            nn.Conv2d(oc[3], oc[3], 2, 1, 0), nn.ReLU(True))       # 2x2 valid conv

    def forward(self, input):
        return self.ConvNet(input)
29
+
30
+
31
class RCNN_FeatureExtractor(nn.Module):
    """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(RCNN_FeatureExtractor, self).__init__()
        # per-stage channel widths, e.g. [64, 128, 256, 512] for the default
        widths = [int(output_channel / 8), int(output_channel / 4),
                  int(output_channel / 2), output_channel]
        self.output_channel = widths
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, widths[0], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),                                  # halve H and W
            GRCL(widths[0], widths[0], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, 2),                                  # halve H and W
            GRCL(widths[0], widths[1], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, (2, 1), (0, 1)),                     # halve H, pad W
            GRCL(widths[1], widths[2], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, (2, 1), (0, 1)),                     # halve H, pad W
            nn.Conv2d(widths[2], widths[3], 2, 1, 0, bias=False),
            nn.BatchNorm2d(widths[3]), nn.ReLU(True))

    def forward(self, input):
        return self.ConvNet(input)
52
+
53
+
54
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(ResNet_FeatureExtractor, self).__init__()
        # fixed stage depths [1, 2, 5, 3] with BasicBlock, as used by FAN
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])

    def forward(self, input):
        return self.ConvNet(input)
63
+
64
+
65
+ # For Gated RCNN
66
# For Gated RCNN
class GRCL(nn.Module):
    """Gated Recurrent Convolution Layer: refines the feed-forward response
    of the input through `num_iteration` gated recurrent updates."""

    def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad):
        super(GRCL, self).__init__()
        # 1x1 gate projections and kernel_size response projections for the
        # feed-forward input u and the recurrent state x
        self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False)
        self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False)
        self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False)
        self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False)

        self.BN_x_init = nn.BatchNorm2d(output_channel)

        self.num_iteration = num_iteration
        self.GRCL = nn.Sequential(*(GRCL_unit(output_channel) for _ in range(num_iteration)))

    def forward(self, input):
        """The input of GRCL is consistent over time t (denoted u(0)), so the
        projections wgf_u / wf_u are computed once and reused every iteration."""
        gate_u = self.wgf_u(input)
        feed_u = self.wf_u(input)
        state = F.relu(self.BN_x_init(feed_u))

        for unit in self.GRCL:
            state = unit(gate_u, self.wgr_x(state), feed_u, self.wr_x(state))

        return state
93
+
94
+
95
class GRCL_unit(nn.Module):
    """One recurrent iteration of a GRCL: a sigmoid gate blends the
    batch-normalized recurrent response into the feed-forward response."""

    def __init__(self, output_channel):
        super(GRCL_unit, self).__init__()
        self.BN_gfu = nn.BatchNorm2d(output_channel)
        self.BN_grx = nn.BatchNorm2d(output_channel)
        self.BN_fu = nn.BatchNorm2d(output_channel)
        self.BN_rx = nn.BatchNorm2d(output_channel)
        self.BN_Gx = nn.BatchNorm2d(output_channel)

    def forward(self, wgf_u, wgr_x, wf_u, wr_x):
        # gate from the projected input and recurrent state
        gate = F.sigmoid(self.BN_gfu(wgf_u) + self.BN_grx(wgr_x))
        # gated recurrent contribution
        recurrent = self.BN_Gx(self.BN_rx(wr_x) * gate)
        return F.relu(self.BN_fu(wf_u) + recurrent)
115
+
116
+
117
class BasicBlock(nn.Module):
    """Standard two-conv residual block (conv-BN-ReLU, conv-BN, add, ReLU)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding (no bias; a BatchNorm follows)"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def forward(self, x):
        # project the shortcut only when a downsample module was supplied
        shortcut = self.downsample(x) if self.downsample is not None else x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.relu(out + shortcut)

        return out
151
+
152
+
153
class ResNet(nn.Module):
    """Recognition-oriented ResNet trunk: four residual stages interleaved
    with conv/BN refinements. Later pooling stages halve height only, so the
    final feature map reads as a left-to-right sequence."""

    def __init__(self, input_channel, output_channel, block, layers):
        # block: residual block class (e.g. BasicBlock); layers: blocks per stage
        super(ResNet, self).__init__()

        # per-stage output widths, e.g. [128, 256, 512, 512] for output_channel=512
        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        # two-conv stem
        self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
        self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0_2 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # stage 1: halve H and W, then residual blocks + conv refinement
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
            0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])

        # stage 2: halve H and W again
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
            1], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])

        # stage 3: halve H only (stride (2, 1)), pad W by 1
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
        self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
            2], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])

        # stage 4: residual blocks, then two convs that squeeze height
        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
        self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
        self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
        self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
            3], kernel_size=2, stride=1, padding=0, bias=False)
        self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of `blocks` blocks; the first block gets a
        1x1-conv downsample shortcut when the channel count or stride changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # stem
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        # stage 1
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # stage 2
        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # stage 3 (height-only pooling from here on)
        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        # stage 4
        x = self.layer4(x)
        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        return x
trainer/modules/prediction.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5
+
6
+
7
class Attention(nn.Module):
    """Attention-based sequence decoder: unrolls an AttentionCell for
    ``batch_max_length + 1`` steps, with teacher forcing during training and
    greedy argmax feedback at inference. Tensors are created on the
    module-level ``device``."""

    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        # [batch] of class indices -> [batch, onehot_dim] one-hot rows
        input_char = input_char.unsqueeze(1)
        batch_size = input_char.size(0)
        one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
        one_hot = one_hot.scatter_(1, input_char, 1)
        return one_hot

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
        output: probability distribution at each step [batch_size x num_steps x num_classes]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for [s] at end of sentence.

        output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
        # (hidden, cell) state for the LSTM cell, zero-initialized
        hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
                  torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))

        if is_train:
            # teacher forcing: feed the ground-truth previous character
            for i in range(num_steps):
                # one-hot vectors for a i-th char. in a batch
                char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
                # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
                hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
                output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: Cell)
            probs = self.generator(output_hiddens)

        else:
            # greedy decoding: feed back the argmax of the previous step
            targets = torch.LongTensor(batch_size).fill_(0).to(device)  # [GO] token
            probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
                hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
                probs_step = self.generator(hidden[0])
                probs[:, i, :] = probs_step
                _, next_input = probs_step.max(1)
                targets = next_input

        return probs  # batch_size x num_steps x num_classes
59
+
60
+
61
class AttentionCell(nn.Module):
    """One step of additive (Bahdanau-style) attention followed by an LSTM cell."""

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2i or h2h should have bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # project encoder states and the previous decoder hidden state into a
        # shared space: [B, T, hidden] + [B, 1, hidden] (broadcast over T)
        encoder_proj = self.i2h(batch_H)
        decoder_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        energies = self.score(torch.tanh(encoder_proj + decoder_proj))  # [B, T, 1]

        # normalized attention weights over encoder steps
        alpha = F.softmax(energies, dim=1)
        # weighted sum of encoder states -> context vector [B, input_size]
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
        lstm_input = torch.cat([context, char_onehots], 1)
        cur_hidden = self.rnn(lstm_input, prev_hidden)
        return cur_hidden, alpha
trainer/modules/sequence_modeling.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class BidirectionalLSTM(nn.Module):
    """Sequence-modeling layer: a single bidirectional LSTM followed by a
    linear projection that maps the concatenated directions to output_size."""

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size]
        output : contextual feature [batch_size x T x output_size]
        """
        try:
            # Compact the LSTM weights for cuDNN; this can fail on some
            # backends (e.g. non-contiguous weights under DataParallel),
            # in which case it is safe to skip.
            self.rnn.flatten_parameters()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; anything else here is non-fatal.
            pass
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output
trainer/modules/transformation.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
+
7
+
8
class TPS_SpatialTransformerNetwork(nn.Module):
    """Rectification network of RARE: a TPS-based spatial transformer.

    Predicts F fiducial points on the input image, derives a thin-plate-spline
    sampling grid from them, and resamples the input into a rectified image.
    """

    def __init__(self, F, I_size, I_r_size, I_channel_num=1):
        """
        F             : number of fiducial points.
        I_size        : (height, width) of the input image I.
        I_r_size      : (height, width) of the rectified image I_r.
        I_channel_num : number of channels of the input image I.
        """
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I):
        """Rectify batch_I: [B x C x H x W] -> [B x C x I_r_height x I_r_width]."""
        fiducial_pts = self.LocalizationNetwork(batch_I)                  # B x F x 2
        flat_grid = self.GridGenerator.build_P_prime(fiducial_pts)        # B x n x 2, n = I_r_w * I_r_h
        sample_grid = flat_grid.reshape([flat_grid.size(0), self.I_r_size[0], self.I_r_size[1], 2])
        # NOTE: the ctor parameter `F` shadows torch.nn.functional only inside
        # __init__; here `F` is the functional module again.
        rectified = F.grid_sample(batch_I, sample_grid, padding_mode='border')
        return rectified
36
+
37
+
38
class LocalizationNetwork(nn.Module):
    """Localization Network of RARE: predicts the fiducial points C' (F x 2)
    from the input image I via a small conv backbone and two FC layers."""

    def __init__(self, F, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.F = F
        self.I_channel_num = I_channel_num
        # Four conv-BN-ReLU stages with 2x pooling, pooled down to a 512-d descriptor.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3,
                      stride=1, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),   # -> 64 x H/2 x W/2
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),   # -> 128 x H/4 x W/4
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),   # -> 256 x H/8 x W/8
            nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1)  # -> 512 x 1 x 1
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Zero the final weights and bias them toward the canonical fiducial
        # layout (RARE paper Fig. 6 (a)) so training starts from a near-identity
        # transform regardless of the conv features.
        self.localization_fc2.weight.data.fill_(0)
        half = int(F / 2)
        span_x = np.linspace(-1.0, 1.0, half)
        top_y = np.linspace(0.0, -1.0, num=half)
        bottom_y = np.linspace(1.0, 0.0, num=half)
        top_pts = np.stack([span_x, top_y], axis=1)
        bottom_pts = np.stack([span_x, bottom_y], axis=1)
        bias_init = np.concatenate([top_pts, bottom_pts], axis=0)
        self.localization_fc2.bias.data = torch.from_numpy(bias_init).float().view(-1)

    def forward(self, batch_I):
        """
        input:  batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
        output: batch_C_prime : predicted fiducial coordinates [batch_size x F x 2]
        """
        n = batch_I.size(0)
        descriptor = self.conv(batch_I).view(n, -1)  # n x 512
        return self.localization_fc2(self.localization_fc1(descriptor)).view(n, self.F, 2)
80
+
81
+
82
class GridGenerator(nn.Module):
    """Grid Generator of RARE: produces the sampling grid P' by solving the
    thin-plate-spline system T from the predicted fiducials and applying it
    to the precomputed pixel grid P."""

    def __init__(self, F, I_r_size):
        """Precompute the fiducial geometry (C, P, inv_delta_C, P_hat) once."""
        super(GridGenerator, self).__init__()
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
        self.F = F
        self.C = self._build_C(self.F)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        # Registered as buffers so they follow the module across devices and
        # replicate correctly under multi-GPU (DataParallel).
        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())      # n x F+3
        # (For fine-tuning with a different image width, plain .cuda() tensors
        # could be used here instead of buffers.)

    def _build_C(self, F):
        """Canonical fiducial coordinates in the rectified image I_r; F x 2."""
        half = int(F / 2)
        xs = np.linspace(-1.0, 1.0, half)
        top = np.stack([xs, -np.ones(half)], axis=1)
        bottom = np.stack([xs, np.ones(half)], axis=1)
        return np.concatenate([top, bottom], axis=0)  # F x 2

    def _build_inv_delta_C(self, F, C):
        """Inverse of the TPS system matrix delta_C, needed to solve for T."""
        hat_C = np.zeros((F, F), dtype=float)  # symmetric pairwise-distance matrix
        for i in range(F):
            for j in range(i, F):
                hat_C[i, j] = hat_C[j, i] = np.linalg.norm(C[i] - C[j])
        np.fill_diagonal(hat_C, 1)  # avoid log(0); 1^2 * log(1) == 0 on the diagonal
        hat_C = (hat_C ** 2) * np.log(hat_C)  # TPS radial basis r^2 log r
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),          # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),  # 1 x F+3
            ],
            axis=0
        )
        return np.linalg.inv(delta_C)  # F+3 x F+3

    def _build_P(self, I_r_width, I_r_height):
        """Output-pixel centers in normalized [-1, 1] coordinates; n x 2."""
        grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width
        grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height
        mesh = np.stack(np.meshgrid(grid_x, grid_y), axis=2)  # I_r_width x I_r_height x 2
        return mesh.reshape([-1, 2])  # n (= I_r_width x I_r_height) x 2

    def _build_P_hat(self, F, C, P):
        """Lifted representation [1, P, rbf(P, C)] of the pixel grid; n x F+3."""
        n = P.shape[0]
        diff = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) - np.expand_dims(C, axis=0)  # n x F x 2
        dist = np.linalg.norm(diff, ord=2, axis=2, keepdims=False)  # n x F
        rbf = np.multiply(np.square(dist), np.log(dist + self.eps))  # eps guards log(0)
        return np.concatenate([np.ones((n, 1)), P, rbf], axis=1)  # n x F+3

    def build_P_prime(self, batch_C_prime):
        """Map predicted fiducials C' [batch_size x F x 2] to the sampling grid P'."""
        batch_size = batch_C_prime.size(0)
        inv_delta = self.inv_delta_C.repeat(batch_size, 1, 1)
        lifted_P = self.P_hat.repeat(batch_size, 1, 1)
        # Pad C' with the three zero rows required by the TPS linear system.
        padded_C_prime = torch.cat(
            (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), dim=1)  # batch_size x F+3 x 2
        T = torch.bmm(inv_delta, padded_C_prime)  # batch_size x F+3 x 2
        return torch.bmm(lifted_P, T)  # batch_size x n x 2
trainer/saved_models/folder.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Trained models will be saved in this directory.
trainer/test.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import string
4
+ import argparse
5
+
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import torch.utils.data
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from nltk.metrics.distance import edit_distance
12
+
13
+ from utils import CTCLabelConverter, AttnLabelConverter, Averager
14
+ from dataset import hierarchical_dataset, AlignCollate
15
+ from model import Model
16
+
17
def validation(model, criterion, evaluation_loader, converter, opt, device):
    """Evaluate `model` over `evaluation_loader` (validation or test pass).

    Computes the evaluation loss (CTC or attention cross-entropy depending on
    opt.Prediction), word-level accuracy (%), the ICDAR2019 normalized edit
    distance, and a per-sample confidence score (product of per-step max
    probabilities).

    Returns:
        (valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
         labels, infer_time, length_of_data)
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' into CTCLoss's expected (T, N, C) layout
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            if opt.decode == 'greedy':
                # Select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            elif opt.decode == 'beamsearch':
                preds_str = converter.decode_beamsearch(preds, beamWidth=2)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []

        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:gt.find('[s]')]
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1

            '''
            (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
            "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)
            '''

            # ICDAR2019 Normalized Edit Distance
            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

            # calculate confidence score (= product of per-step max probabilities)
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                # Narrowed from a bare `except:`: only the empty-prediction case
                # (pred pruned to nothing at the [s] token, leaving cumprod of an
                # empty tensor) should be silenced here.
                confidence_score = 0
            confidence_score_list.append(confidence_score)
            # print(pred, gt, pred==gt, confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
trainer/train.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import random
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ import torch.nn as nn
8
+ import torch.nn.init as init
9
+ import torch.optim as optim
10
+ import torch.utils.data
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ import numpy as np
13
+
14
+ from utils import CTCLabelConverter, AttnLabelConverter, Averager
15
+ from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
16
+ from model import Model
17
+ from test import validation
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+
20
def count_parameters(model):
    """Print every trainable parameter tensor of *model* (name and element
    count) and return the total number of trainable elements."""
    print("Modules, Parameters")
    total_params = 0
    for name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        numel = tensor.numel()
        print(name, numel)
        total_params += numel
    print(f"Total Trainable Params: {total_params}")
    return total_params
31
+
32
def train(opt, show_number = 2, amp=False):
    """Run the full recognition-model training loop.

    opt         : argparse-style namespace holding data/model/optimizer options.
    show_number : how many validation predictions to print at each eval step.
    amp         : if True, train with automatic mixed precision (autocast + GradScaler).

    Side effects: writes logs, options, and checkpoints under
    ./saved_models/<opt.experiment_name>/ and terminates the process with
    sys.exit() once opt.num_iter iterations are reached.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    # select_data / batch_ratio arrive as '-'-separated strings, e.g. 'MJ-ST' / '0.5-0.5'.
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=min(32, opt.batch_size),
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers), prefetch_factor=512,
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    # The label converter determines the number of output classes.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            # Temporarily resize the head to match the checkpoint so it loads cleanly;
            # it is replaced with the new class count further below.
            model.Prediction = nn.Linear(model.SequenceModeling_output, len(pretrained_dict['module.Prediction.weight']))

        # Wrap in DataParallel first: checkpoint keys carry the 'module.' prefix.
        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            # Swap in a freshly initialized prediction head for the new charset.
            model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                # TPS localization fc2 is initialized to the fiducial layout already.
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # freeze some layers
    # NOTE(review): 'freeze_FeatureFxtraction' is misspelled but matches the
    # option name used elsewhere in this project — do not "fix" it here alone.
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except:
        pass

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim=='adam':
        #optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        # Resume the iteration counter from a checkpoint named like 'iter_50000.pth'.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    scaler = GradScaler()
    t1= time.time()

    while(True):
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            with autocast():
                image_tensors, labels = train_dataset.get_batch()
                image = image_tensors.to(device)
                text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
                batch_size = image.size(0)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text).log_softmax(2)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    preds = preds.permute(1, 0, 2)
                    # cuDNN's CTC path is disabled around the loss for stability.
                    torch.backends.cudnn.enabled = False
                    cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                    torch.backends.cudnn.enabled = True
                else:
                    preds = model(image, text[:, :-1])  # align with Attention.forward
                    target = text[:, 1:]  # without [GO] Symbol
                    cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            # Unscale before clipping so grad_clip applies to true gradient magnitudes.
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i!=0):
            print('training time: ', time.time()-t1)
            t1=time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                #show_number = min(show_number, len(labels))

                # Pick a random window of `show_number` consecutive samples to display.
                start = random.randint(0,len(labels) - show_number )
                for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time()-t1)
                t1=time.time()
        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1