Commit ·
b4959be
1
Parent(s): cc174d2
Initial commit of EasyOCR model
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- LICENSE +201 -0
- MANIFEST.in +8 -0
- README.md +178 -0
- custom_model.md +24 -0
- model.py +24 -0
- requirements.txt +12 -0
- scripts/.gitignore +2 -0
- scripts/generate-ja.rb +55 -0
- setup.cfg +2 -0
- setup.py +35 -0
- trainer/README.md +3 -0
- trainer/all_data/folder.txt +1 -0
- trainer/config_files/en_filtered_config.yaml +45 -0
- trainer/craft/.gitignore +4 -0
- trainer/craft/README.md +105 -0
- trainer/craft/config/__init__.py +0 -0
- trainer/craft/config/custom_data_train.yaml +100 -0
- trainer/craft/config/load_config.py +37 -0
- trainer/craft/config/syn_train.yaml +68 -0
- trainer/craft/data/boxEnlarge.py +65 -0
- trainer/craft/data/dataset.py +542 -0
- trainer/craft/data/gaussian.py +192 -0
- trainer/craft/data/imgaug.py +175 -0
- trainer/craft/data/imgproc.py +91 -0
- trainer/craft/data/pseudo_label/make_charbox.py +263 -0
- trainer/craft/data/pseudo_label/watershed.py +45 -0
- trainer/craft/data_root_dir/folder.txt +1 -0
- trainer/craft/eval.py +381 -0
- trainer/craft/exp/folder.txt +1 -0
- trainer/craft/loss/mseloss.py +172 -0
- trainer/craft/metrics/eval_det_iou.py +244 -0
- trainer/craft/model/craft.py +112 -0
- trainer/craft/model/vgg16_bn.py +77 -0
- trainer/craft/requirements.txt +10 -0
- trainer/craft/scripts/run_cde.sh +7 -0
- trainer/craft/train.py +479 -0
- trainer/craft/trainSynth.py +409 -0
- trainer/craft/train_distributed.py +523 -0
- trainer/craft/utils/craft_utils.py +345 -0
- trainer/craft/utils/inference_boxes.py +361 -0
- trainer/craft/utils/util.py +142 -0
- trainer/dataset.py +283 -0
- trainer/model.py +74 -0
- trainer/modules/feature_extraction.py +246 -0
- trainer/modules/prediction.py +81 -0
- trainer/modules/sequence_modeling.py +22 -0
- trainer/modules/transformation.py +160 -0
- trainer/saved_models/folder.txt +1 -0
- trainer/test.py +112 -0
- trainer/train.py +282 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include LICENSE.txt
|
| 2 |
+
include README.md
|
| 3 |
+
|
| 4 |
+
include easyocr/model/*
|
| 5 |
+
include easyocr/character/*
|
| 6 |
+
include easyocr/dict/*
|
| 7 |
+
include easyocr/scripts/compile_dbnet_dcn.py
|
| 8 |
+
recursive-include easyocr/DBNet *
|
README.md
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EasyOCR
|
| 2 |
+
|
| 3 |
+
[](https://badge.fury.io/py/easyocr)
|
| 4 |
+
[](https://github.com/JaidedAI/EasyOCR/blob/master/LICENSE)
|
| 5 |
+
[](https://colab.to/easyocr)
|
| 6 |
+
[](https://twitter.com/intent/tweet?text=Check%20out%20this%20awesome%20library:%20EasyOCR%20https://github.com/JaidedAI/EasyOCR)
|
| 7 |
+
[](https://twitter.com/JaidedAI)
|
| 8 |
+
|
| 9 |
+
Ready-to-use OCR with 80+ [supported languages](https://www.jaided.ai/easyocr) and all popular writing scripts including: Latin, Chinese, Arabic, Devanagari, Cyrillic, etc.
|
| 10 |
+
|
| 11 |
+
[Try Demo on our website](https://www.jaided.ai/easyocr)
|
| 12 |
+
|
| 13 |
+
Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [](https://huggingface.co/spaces/tomofi/EasyOCR)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## What's new
|
| 17 |
+
- 24 September 2024 - Version 1.7.2
|
| 18 |
+
- Fix several compatibilities
|
| 19 |
+
|
| 20 |
+
- [Read all release notes](https://github.com/JaidedAI/EasyOCR/blob/master/releasenotes.md)
|
| 21 |
+
|
| 22 |
+
## What's coming next
|
| 23 |
+
- Handwritten text support
|
| 24 |
+
|
| 25 |
+
## Examples
|
| 26 |
+
|
| 27 |
+

|
| 28 |
+
|
| 29 |
+

|
| 30 |
+
|
| 31 |
+

|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## Installation
|
| 35 |
+
|
| 36 |
+
Install using `pip`
|
| 37 |
+
|
| 38 |
+
For the latest stable release:
|
| 39 |
+
|
| 40 |
+
``` bash
|
| 41 |
+
pip install easyocr
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
For the latest development release:
|
| 45 |
+
|
| 46 |
+
``` bash
|
| 47 |
+
pip install git+https://github.com/JaidedAI/EasyOCR.git
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Note 1: For Windows, please install torch and torchvision first by following the official instructions here https://pytorch.org. On the pytorch website, be sure to select the right CUDA version you have. If you intend to run on CPU mode only, select `CUDA = None`.
|
| 51 |
+
|
| 52 |
+
Note 2: We also provide a Dockerfile [here](https://github.com/JaidedAI/EasyOCR/blob/master/Dockerfile).
|
| 53 |
+
|
| 54 |
+
## Usage
|
| 55 |
+
|
| 56 |
+
``` python
|
| 57 |
+
import easyocr
|
| 58 |
+
reader = easyocr.Reader(['ch_sim','en']) # this needs to run only once to load the model into memory
|
| 59 |
+
result = reader.readtext('chinese.jpg')
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
The output will be in a list format, each item represents a bounding box, the text detected and confidence level, respectively.
|
| 63 |
+
|
| 64 |
+
``` bash
|
| 65 |
+
[([[189, 75], [469, 75], [469, 165], [189, 165]], '愚园路', 0.3754989504814148),
|
| 66 |
+
([[86, 80], [134, 80], [134, 128], [86, 128]], '西', 0.40452659130096436),
|
| 67 |
+
([[517, 81], [565, 81], [565, 123], [517, 123]], '东', 0.9989598989486694),
|
| 68 |
+
([[78, 126], [136, 126], [136, 156], [78, 156]], '315', 0.8125889301300049),
|
| 69 |
+
([[514, 126], [574, 126], [574, 156], [514, 156]], '309', 0.4971577227115631),
|
| 70 |
+
([[226, 170], [414, 170], [414, 220], [226, 220]], 'Yuyuan Rd.', 0.8261902332305908),
|
| 71 |
+
([[79, 173], [125, 173], [125, 213], [79, 213]], 'W', 0.9848111271858215),
|
| 72 |
+
([[529, 173], [569, 173], [569, 213], [529, 213]], 'E', 0.8405593633651733)]
|
| 73 |
+
```
|
| 74 |
+
Note 1: `['ch_sim','en']` is the list of languages you want to read. You can pass
|
| 75 |
+
several languages at once but not all languages can be used together.
|
| 76 |
+
English is compatible with every language and languages that share common characters are usually compatible with each other.
|
| 77 |
+
|
| 78 |
+
Note 2: Instead of the filepath `chinese.jpg`, you can also pass an OpenCV image object (numpy array) or an image file as bytes. A URL to a raw image is also acceptable.
|
| 79 |
+
|
| 80 |
+
Note 3: The line `reader = easyocr.Reader(['ch_sim','en'])` is for loading a model into memory. It takes some time but it needs to be run only once.
|
| 81 |
+
|
| 82 |
+
You can also set `detail=0` for simpler output.
|
| 83 |
+
|
| 84 |
+
``` python
|
| 85 |
+
reader.readtext('chinese.jpg', detail = 0)
|
| 86 |
+
```
|
| 87 |
+
Result:
|
| 88 |
+
``` bash
|
| 89 |
+
['愚园路', '西', '东', '315', '309', 'Yuyuan Rd.', 'W', 'E']
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Model weights for the chosen language will be automatically downloaded or you can
|
| 93 |
+
download them manually from the [model hub](https://www.jaided.ai/easyocr/modelhub) and put them in the '~/.EasyOCR/model' folder
|
| 94 |
+
|
| 95 |
+
In case you do not have a GPU, or your GPU has low memory, you can run the model in CPU-only mode by adding `gpu=False`.
|
| 96 |
+
|
| 97 |
+
``` python
|
| 98 |
+
reader = easyocr.Reader(['ch_sim','en'], gpu=False)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
For more information, read the [tutorial](https://www.jaided.ai/easyocr/tutorial) and [API Documentation](https://www.jaided.ai/easyocr/documentation).
|
| 102 |
+
|
| 103 |
+
#### Run on command line
|
| 104 |
+
|
| 105 |
+
```shell
|
| 106 |
+
$ easyocr -l ch_sim en -f chinese.jpg --detail=1 --gpu=True
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Train/use your own model
|
| 110 |
+
|
| 111 |
+
For recognition model, [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/custom_model.md).
|
| 112 |
+
|
| 113 |
+
For detection model (CRAFT), [Read here](https://github.com/JaidedAI/EasyOCR/blob/master/trainer/craft/README.md).
|
| 114 |
+
|
| 115 |
+
## Implementation Roadmap
|
| 116 |
+
|
| 117 |
+
- Handwritten support
|
| 118 |
+
- Restructure code to support swappable detection and recognition algorithms
|
| 119 |
+
The api should be as easy as
|
| 120 |
+
``` python
|
| 121 |
+
reader = easyocr.Reader(['en'], detection='DB', recognition = 'Transformer')
|
| 122 |
+
```
|
| 123 |
+
The idea is to be able to plug in any state-of-the-art model into EasyOCR. There are a lot of geniuses trying to make better detection/recognition models, but we are not trying to be geniuses here. We just want to make their works quickly accessible to the public ... for free. (well, we believe most geniuses want their work to create a positive impact as fast/big as possible) The pipeline should be something like the below diagram. Grey slots are placeholders for changeable light blue modules.
|
| 124 |
+
|
| 125 |
+

|
| 126 |
+
|
| 127 |
+
## Acknowledgement and References
|
| 128 |
+
|
| 129 |
+
This project is based on research and code from several papers and open-source repositories.
|
| 130 |
+
|
| 131 |
+
All deep learning execution is based on [Pytorch](https://pytorch.org). :heart:
|
| 132 |
+
|
| 133 |
+
Detection execution uses the CRAFT algorithm from this [official repository](https://github.com/clovaai/CRAFT-pytorch) and their [paper](https://arxiv.org/abs/1904.01941) (Thanks @YoungminBaek from [@clovaai](https://github.com/clovaai)). We also use their pretrained model. Training script is provided by [@gmuffiness](https://github.com/gmuffiness).
|
| 134 |
+
|
| 135 |
+
The recognition model is a CRNN ([paper](https://arxiv.org/abs/1507.05717)). It is composed of 3 main components: feature extraction (we are currently using [Resnet](https://arxiv.org/abs/1512.03385)) and VGG, sequence labeling ([LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf)) and decoding ([CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf)). The training pipeline for recognition execution is a modified version of the [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) framework. (Thanks [@ku21fan](https://github.com/ku21fan) from [@clovaai](https://github.com/clovaai)) This repository is a gem that deserves more recognition.
|
| 136 |
+
|
| 137 |
+
Beam search code is based on this [repository](https://github.com/githubharald/CTCDecoder) and his [blog](https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7). (Thanks [@githubharald](https://github.com/githubharald))
|
| 138 |
+
|
| 139 |
+
Data synthesis is based on [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). (Thanks [@Belval](https://github.com/Belval))
|
| 140 |
+
|
| 141 |
+
And a good read about CTC from distill.pub [here](https://distill.pub/2017/ctc/).
|
| 142 |
+
|
| 143 |
+
## Want To Contribute?
|
| 144 |
+
|
| 145 |
+
Let's advance humanity together by making AI available to everyone!
|
| 146 |
+
|
| 147 |
+
3 ways to contribute:
|
| 148 |
+
|
| 149 |
+
**Coder:** Please send a PR for small bugs/improvements. For bigger ones, discuss with us by opening an issue first. There is a list of possible bug/improvement issues tagged with ['PR WELCOME'](https://github.com/JaidedAI/EasyOCR/issues?q=is%3Aissue+is%3Aopen+label%3A%22PR+WELCOME%22).
|
| 150 |
+
|
| 151 |
+
**User:** Tell us how EasyOCR benefits you/your organization to encourage further development. Also post failure cases in [Issue Section](https://github.com/JaidedAI/EasyOCR/issues) to help improve future models.
|
| 152 |
+
|
| 153 |
+
**Tech leader/Guru:** If you found this library useful, please spread the word! (See [Yann Lecun's post](https://www.facebook.com/yann.lecun/posts/10157018122787143) about EasyOCR)
|
| 154 |
+
|
| 155 |
+
## Guideline for new language request
|
| 156 |
+
|
| 157 |
+
To request a new language, we need you to send a PR with the 2 following files:
|
| 158 |
+
|
| 159 |
+
1. In folder [easyocr/character](https://github.com/JaidedAI/EasyOCR/tree/master/easyocr/character),
|
| 160 |
+
we need 'yourlanguagecode_char.txt' that contains list of all characters. Please see format examples from other files in that folder.
|
| 161 |
+
2. In folder [easyocr/dict](https://github.com/JaidedAI/EasyOCR/tree/master/easyocr/dict),
|
| 162 |
+
we need 'yourlanguagecode.txt' that contains list of words in your language.
|
| 163 |
+
On average, we have ~30000 words per language with more than 50000 words for more popular ones.
|
| 164 |
+
More is better in this file.
|
| 165 |
+
|
| 166 |
+
If your language has unique elements (such as 1. Arabic: characters change form when attached to each other + write from right to left 2. Thai: Some characters need to be above the line and some below), please educate us to the best of your ability and/or give useful links. It is important to take care of the detail to achieve a system that really works.
|
| 167 |
+
|
| 168 |
+
Lastly, please understand that our priority will have to go to popular languages or sets of languages that share large portions of their characters with each other (also tell us if this is the case for your language). It takes us at least a week to develop a new model, so you may have to wait a while for the new model to be released.
|
| 169 |
+
|
| 170 |
+
See [List of languages in development](https://github.com/JaidedAI/EasyOCR/issues/91)
|
| 171 |
+
|
| 172 |
+
## Github Issues
|
| 173 |
+
|
| 174 |
+
Due to limited resources, an issue older than 6 months will be automatically closed. Please open an issue again if it is critical.
|
| 175 |
+
|
| 176 |
+
## Business Inquiries
|
| 177 |
+
|
| 178 |
+
For Enterprise Support, [Jaided AI](https://www.jaided.ai/) offers full service for custom OCR/AI systems from implementation, training/finetuning and deployment. Click [here](https://www.jaided.ai/contactus?ref=github) to contact us.
|
custom_model.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Custom recognition models
|
| 2 |
+
|
| 3 |
+
## How to train your custom model
|
| 4 |
+
|
| 5 |
+
You can use your own data or generate your own dataset. To generate your own data, we recommend using
|
| 6 |
+
[TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). We provide an example of a dataset [here](https://jaided.ai/easyocr/modelhub/).
|
| 7 |
+
After you have a dataset, you can train your own model by following this repository
|
| 8 |
+
[deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark).
|
| 9 |
+
The network needs to be fully convolutional in order to predict flexible text length. Our current network is 'None-VGG-BiLSTM-CTC'.
|
| 10 |
+
Once you have your trained model (a `.pth` file), you need 2 additional files describing recognition network architecture and model configuration.
|
| 11 |
+
An example is provided in `custom_example.zip` file [here](https://jaided.ai/easyocr/modelhub/).
|
| 12 |
+
|
| 13 |
+
Please do not create an issue about data generation and model training in this repository. If you have any question regarding data generation and model training, please ask in the respective repositories.
|
| 14 |
+
|
| 15 |
+
Note: We also provide our version of a training script [here](https://github.com/JaidedAI/EasyOCR/tree/master/trainer). It is a modified version from [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark).
|
| 16 |
+
|
| 17 |
+
## How to use your custom model
|
| 18 |
+
|
| 19 |
+
To use your own recognition model, you need the three files as explained above. These three files have to share the same name (i.e. `yourmodel.pth`, `yourmodel.yaml`, `yourmodel.py`) that you will then use to call your model with EasyOCR API.
|
| 20 |
+
|
| 21 |
+
We provide [custom_example.zip](https://jaided.ai/easyocr/modelhub/)
|
| 22 |
+
as an example. Please download, extract and place `custom_example.py`, `custom_example.yaml` in the `user_network_directory` (default = `~/.EasyOCR/user_network`) and place `custom_example.pth` in model directory (default = `~/.EasyOCR/model`)
|
| 23 |
+
Once you place all 3 files in their respective places, you can use `custom_example` by
|
| 24 |
+
specifying `recog_network` like this `reader = easyocr.Reader(['en'], recog_network='custom_example')`.
|
model.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import easyocr
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
class EasyOCRModel:
|
| 5 |
+
def __init__(self):
|
| 6 |
+
self.reader = easyocr.Reader(['en']) # Initialize with English; add languages if needed.
|
| 7 |
+
|
| 8 |
+
def predict(self, image_path: str) -> List[str]:
|
| 9 |
+
"""
|
| 10 |
+
Perform OCR on the given image.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
image_path (str): Path to the input image.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
List[str]: Extracted text from the image.
|
| 17 |
+
"""
|
| 18 |
+
return self.reader.readtext(image_path, detail=0)
|
| 19 |
+
|
| 20 |
+
# Test the model locally
|
| 21 |
+
if __name__ == "__main__":
|
| 22 |
+
model = EasyOCRModel()
|
| 23 |
+
result = model.predict("sample_image.jpg")
|
| 24 |
+
print(result)
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision>=0.5
|
| 3 |
+
opencv-python-headless
|
| 4 |
+
scipy
|
| 5 |
+
numpy
|
| 6 |
+
Pillow
|
| 7 |
+
scikit-image
|
| 8 |
+
python-bidi
|
| 9 |
+
PyYAML
|
| 10 |
+
Shapely
|
| 11 |
+
pyclipper
|
| 12 |
+
ninja
|
scripts/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
JMnedict*
|
| 2 |
+
JMdict*
|
scripts/generate-ja.rb
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# frozen_string_literal: true
|
| 2 |
+
|
| 3 |
+
require 'json'
|
| 4 |
+
require 'nokogiri'
|
| 5 |
+
require 'parallel'
|
| 6 |
+
require 'ruby-progressbar'
|
| 7 |
+
|
| 8 |
+
JMDICT_XML = 'JMdict_e'
|
| 9 |
+
JMNEDICT_XML = 'JMnedict.xml'
|
| 10 |
+
PUNC = '【】《》〈〉⦅⦆{}[]〔〕()『』「」、;:・?〜=。!⁉︎‥…〜※*〽♪♫♬♩〇〒〶〠〄ⓍⓁⓎ→'.chars
|
| 11 |
+
|
| 12 |
+
# Download and unpack a JMdict/JMnedict archive next to this script,
# unless the extracted file already exists.
def download_dict(xml)
  return if File.exist?(File.expand_path(xml, __dir__))

  archive = "#{xml}.gz"
  url = "http://ftp.monash.edu/pub/nihongo/#{archive}"
  # NOTE(review): exit status of the shell pipeline is ignored; a failed
  # download leaves no file and read_dict will fail later — confirm acceptable.
  `cd #{File.dirname(__FILE__)} && wget #{url} && gunzip #{archive}`
end
|
| 19 |
+
|
| 20 |
+
# Collect every written form of a dictionary entry:
# kanji forms (k_ele/keb) followed by kana readings (r_ele/reb).
def read_word(word)
  kanji_forms = word.css('k_ele keb').map(&:text)
  reading_forms = word.css('r_ele reb').map(&:text)
  kanji_forms + reading_forms
end
|
| 23 |
+
|
| 24 |
+
# Parse a JMdict-style XML file and return all word forms it contains,
# extracting entries in parallel with a progress bar labelled by the root tag.
def read_dict(filename, root)
  doc = Nokogiri::XML(File.open(File.expand_path(filename, __dir__)))
  entries = doc.css("#{root} > entry")
  Parallel.flat_map(entries, in_threads: 16, progress: root) do |entry|
    read_word(entry)
  end
end
|
| 31 |
+
|
| 32 |
+
# Write the word list and derived character sets into the easyocr source
# tree, and report how the character inventory changed versus ja_char.txt.
def write_files(words)
  src_dir = File.expand_path('../easyocr', __dir__)
  ja_dict = File.join(src_dir, 'dict', 'ja.txt')
  ja_char = File.join(src_dir, 'character', 'ja_char2.txt')
  ja_char_old = File.join(src_dir, 'character', 'ja_char.txt')
  ja_punc = File.join(src_dir, 'character', 'ja_punc.txt')

  words -= PUNC                                  # punctuation is tracked separately
  chars = words.join.chars.uniq                  # every distinct character in use
  chars_old = IO.read(ja_char_old).split("\n")

  # Diff against the previous character set so regressions are visible.
  puts "new characters: #{(chars - chars_old).size}"
  puts "missing characters: #{(chars_old - chars).size}"
  puts chars_old - chars

  IO.write(ja_dict, words.join("\n"))
  IO.write(ja_char, chars.join("\n"))
  IO.write(ja_punc, PUNC.join("\n"))
end
|
| 51 |
+
|
| 52 |
+
# Entry point: fetch both dictionaries, parse all word forms, write outputs.
download_dict(JMDICT_XML)
download_dict(JMNEDICT_XML)
words = read_dict(JMDICT_XML, 'JMdict') + read_dict(JMNEDICT_XML, 'JMnedict')
write_files(words)
|
setup.cfg
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[metadata]
|
| 2 |
+
description_file = README.md
|
setup.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution
|
| 3 |
+
"""
|
| 4 |
+
from io import open
|
| 5 |
+
from setuptools import setup
|
| 6 |
+
|
| 7 |
+
with open('requirements.txt', encoding="utf-8-sig") as f:
|
| 8 |
+
requirements = f.readlines()
|
| 9 |
+
|
| 10 |
+
def readme():
|
| 11 |
+
with open('README.md', encoding="utf-8-sig") as f:
|
| 12 |
+
README = f.read()
|
| 13 |
+
return README
|
| 14 |
+
|
| 15 |
+
setup(
|
| 16 |
+
name='easyocr',
|
| 17 |
+
packages=['easyocr'],
|
| 18 |
+
include_package_data=True,
|
| 19 |
+
version='1.7.2',
|
| 20 |
+
install_requires=requirements,
|
| 21 |
+
entry_points={"console_scripts": ["easyocr= easyocr.cli:main"]},
|
| 22 |
+
license='Apache License 2.0',
|
| 23 |
+
description='End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution',
|
| 24 |
+
long_description=readme(),
|
| 25 |
+
long_description_content_type="text/markdown",
|
| 26 |
+
author='Rakpong Kittinaradorn',
|
| 27 |
+
author_email='r.kittinaradorn@gmail.com',
|
| 28 |
+
url='https://github.com/jaidedai/easyocr',
|
| 29 |
+
download_url='https://github.com/jaidedai/easyocr.git',
|
| 30 |
+
keywords=['ocr optical character recognition deep learning neural network'],
|
| 31 |
+
classifiers=[
|
| 32 |
+
'Development Status :: 5 - Production/Stable'
|
| 33 |
+
],
|
| 34 |
+
|
| 35 |
+
)
|
trainer/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EasyOCR trainer
|
| 2 |
+
|
| 3 |
+
use `trainer.ipynb` with yaml config in `config_files` folder
|
trainer/all_data/folder.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
place dataset folder here
|
trainer/config_files/en_filtered_config.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
number: '0123456789'
|
| 2 |
+
symbol: "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €"
|
| 3 |
+
lang_char: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
| 4 |
+
experiment_name: 'en_filtered'
|
| 5 |
+
train_data: 'all_data'
|
| 6 |
+
valid_data: 'all_data/en_val'
|
| 7 |
+
manualSeed: 1111
|
| 8 |
+
workers: 6
|
| 9 |
+
batch_size: 32 #32
|
| 10 |
+
num_iter: 300000
|
| 11 |
+
valInterval: 20000
|
| 12 |
+
saved_model: '' #'saved_models/en_filtered/iter_300000.pth'
|
| 13 |
+
FT: False
|
| 14 |
+
optim: False # default is Adadelta
|
| 15 |
+
lr: 1.
|
| 16 |
+
beta1: 0.9
|
| 17 |
+
rho: 0.95
|
| 18 |
+
eps: 0.00000001
|
| 19 |
+
grad_clip: 5
|
| 20 |
+
#Data processing
|
| 21 |
+
select_data: 'en_train_filtered' # this is dataset folder in train_data
|
| 22 |
+
batch_ratio: '1'
|
| 23 |
+
total_data_usage_ratio: 1.0
|
| 24 |
+
batch_max_length: 34
|
| 25 |
+
imgH: 64
|
| 26 |
+
imgW: 600
|
| 27 |
+
rgb: False
|
| 28 |
+
# contrast_adjust: False  # duplicate key removed — superseded by `contrast_adjust: 0.0` below
|
| 29 |
+
sensitive: True
|
| 30 |
+
PAD: True
|
| 31 |
+
contrast_adjust: 0.0
|
| 32 |
+
data_filtering_off: False
|
| 33 |
+
# Model Architecture
|
| 34 |
+
Transformation: 'None'
|
| 35 |
+
FeatureExtraction: 'VGG'
|
| 36 |
+
SequenceModeling: 'BiLSTM'
|
| 37 |
+
Prediction: 'CTC'
|
| 38 |
+
num_fiducial: 20
|
| 39 |
+
input_channel: 1
|
| 40 |
+
output_channel: 256
|
| 41 |
+
hidden_size: 256
|
| 42 |
+
decode: 'greedy'
|
| 43 |
+
new_prediction: False
|
| 44 |
+
freeze_FeatureFxtraction: False  # NOTE: key name is misspelled ("Fxtraction") — presumably matched verbatim by the trainer code; confirm before renaming
|
| 45 |
+
freeze_SequenceModeling: False
|
trainer/craft/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
model/__pycache__/
|
| 3 |
+
wandb/*
|
| 4 |
+
vis_result/*
|
trainer/craft/README.md
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CRAFT-train
|
| 2 |
+
On the official CRAFT GitHub repository, many people have asked for a way to train CRAFT models.
|
| 3 |
+
|
| 4 |
+
However, the training code is not published in the official CRAFT repository.
|
| 5 |
+
|
| 6 |
+
There are other reproduced codes, but there is a gap between their performance and performance reported in the original paper. (https://arxiv.org/pdf/1904.01941.pdf)
|
| 7 |
+
|
| 8 |
+
The trained model with this code recorded a level of performance similar to that of the original paper.
|
| 9 |
+
|
| 10 |
+
```bash
|
| 11 |
+
├── config
|
| 12 |
+
│ ├── syn_train.yaml
|
| 13 |
+
│ └── custom_data_train.yaml
|
| 14 |
+
├── data
|
| 15 |
+
│ ├── pseudo_label
|
| 16 |
+
│ │ ├── make_charbox.py
|
| 17 |
+
│ │ └── watershed.py
|
| 18 |
+
│ ├── boxEnlarge.py
|
| 19 |
+
│ ├── dataset.py
|
| 20 |
+
│ ├── gaussian.py
|
| 21 |
+
│ ├── imgaug.py
|
| 22 |
+
│ └── imgproc.py
|
| 23 |
+
├── loss
|
| 24 |
+
│ └── mseloss.py
|
| 25 |
+
├── metrics
|
| 26 |
+
│ └── eval_det_iou.py
|
| 27 |
+
├── model
|
| 28 |
+
│ ├── craft.py
|
| 29 |
+
│ └── vgg16_bn.py
|
| 30 |
+
├── utils
|
| 31 |
+
│ ├── craft_utils.py
|
| 32 |
+
│ ├── inference_boxes.py
|
| 33 |
+
│ └── utils.py
|
| 34 |
+
├── trainSynth.py
|
| 35 |
+
├── train.py
|
| 36 |
+
├── train_distributed.py
|
| 37 |
+
├── eval.py
|
| 38 |
+
├── data_root_dir (place dataset folder here)
|
| 39 |
+
└── exp (model and experiment result files will be saved here)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### Installation
|
| 43 |
+
|
| 44 |
+
Install using `pip`
|
| 45 |
+
|
| 46 |
+
``` bash
|
| 47 |
+
pip install -r requirements.txt
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
### Training
|
| 52 |
+
1. Put your training, test data in the following format
|
| 53 |
+
```
|
| 54 |
+
└── data_root_dir (you can change root dir in yaml file)
|
| 55 |
+
├── ch4_training_images
|
| 56 |
+
│ ├── img_1.jpg
|
| 57 |
+
│ └── img_2.jpg
|
| 58 |
+
├── ch4_training_localization_transcription_gt
|
| 59 |
+
│ ├── gt_img_1.txt
|
| 60 |
+
│ └── gt_img_2.txt
|
| 61 |
+
├── ch4_test_images
|
| 62 |
+
│ ├── img_1.jpg
|
| 63 |
+
│ └── img_2.jpg
|
| 64 |
+
    └── ch4_test_localization_transcription_gt
|
| 65 |
+
├── gt_img_1.txt
|
| 66 |
+
└── gt_img_2.txt
|
| 67 |
+
```
|
| 68 |
+
* localization_transcription_gt files format :
|
| 69 |
+
```
|
| 70 |
+
377,117,463,117,465,130,378,130,Genaxis Theatre
|
| 71 |
+
493,115,519,115,519,131,493,131,[06]
|
| 72 |
+
374,155,409,155,409,170,374,170,###
|
| 73 |
+
```
|
| 74 |
+
2. Write configuration in yaml format (example config files are provided in `config` folder.)
|
| 75 |
+
* To speed up training time with multi-gpu, set num_worker > 0
|
| 76 |
+
3. Put the yaml file in the config folder
|
| 77 |
+
4. Run training script like below (If you have multi-gpu, run train_distributed.py)
|
| 78 |
+
5. Then, experiment results will be saved to ```./exp/[yaml]``` by default.
|
| 79 |
+
|
| 80 |
+
* Step 1 : To train CRAFT with SynthText dataset from scratch
|
| 81 |
+
* Note : This step is not necessary if you use <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">this pretrain</a> as a checkpoint when start training step 2. You can download and put it in `exp/CRAFT_clr_amp_29500.pth` and change `ckpt_path` in the config file according to your local setup.
|
| 82 |
+
```
|
| 83 |
+
CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
* Step 2 : To train CRAFT with [SynthText + IC15] or custom dataset
|
| 87 |
+
```
|
| 88 |
+
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train ## if you run on single GPU
|
| 89 |
+
CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multi GPU
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Arguments
|
| 93 |
+
* ```--yaml``` : configuration file name
|
| 94 |
+
|
| 95 |
+
### Evaluation
|
| 96 |
+
* In the official repository issues, the author mentioned that the first row setting F1-score is around 0.75.
|
| 97 |
+
* In the official paper, it is stated that the result F1-score of the second row setting is 0.87.
|
| 98 |
+
* If you adjust the post-process parameter 'text_threshold' from 0.85 to 0.75, the F1-score reaches 0.856.
|
| 99 |
+
* It took 14h to train weak-supervision 25k iteration with 8 RTX 3090 Ti.
|
| 100 |
+
* Half of GPU assigned for training, and half of GPU assigned for supervision setting.
|
| 101 |
+
|
| 102 |
+
| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model |
|
| 103 |
+
| ------------- |-----|:-----:|:-----:|:-----:|-----:|
|
| 104 |
+
| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| <a href="https://drive.google.com/file/d/1enVIsgNvBf3YiRsVkxodspOn55PIK-LJ/view?usp=sharing">download link</a>|
|
| 105 |
+
| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| <a href="https://drive.google.com/file/d/1qUeZIDSFCOuGS9yo8o0fi-zYHLEW6lBP/view">download link</a>|
|
trainer/craft/config/__init__.py
ADDED
|
File without changes
|
trainer/craft/config/custom_data_train.yaml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_opt: False
|
| 2 |
+
|
| 3 |
+
results_dir: "./exp/"
|
| 4 |
+
vis_test_dir: "./vis_result/"
|
| 5 |
+
|
| 6 |
+
data_root_dir: "./data_root_dir/"
|
| 7 |
+
score_gt_dir: None # "/data/ICDAR2015_official_supervision"
|
| 8 |
+
mode: "weak_supervision"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
train:
|
| 12 |
+
backbone : vgg
|
| 13 |
+
use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option
|
| 14 |
+
synth_data_dir: "/data/SynthText/"
|
| 15 |
+
synth_ratio: 5
|
| 16 |
+
real_dataset: custom
|
| 17 |
+
ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
|
| 18 |
+
eval_interval: 1000
|
| 19 |
+
batch_size: 5
|
| 20 |
+
st_iter: 0
|
| 21 |
+
end_iter: 25000
|
| 22 |
+
lr: 0.0001
|
| 23 |
+
lr_decay: 7500
|
| 24 |
+
gamma: 0.2
|
| 25 |
+
weight_decay: 0.00001
|
| 26 |
+
num_workers: 0 # On single gpu, train.py execution only works when num worker = 0 / On multi-gpu, you can set num_worker > 0 to speed up
|
| 27 |
+
amp: True
|
| 28 |
+
loss: 2
|
| 29 |
+
neg_rto: 0.3
|
| 30 |
+
n_min_neg: 5000
|
| 31 |
+
data:
|
| 32 |
+
vis_opt: False
|
| 33 |
+
pseudo_vis_opt: False
|
| 34 |
+
output_size: 768
|
| 35 |
+
do_not_care_label: ['###', '']
|
| 36 |
+
mean: [0.485, 0.456, 0.406]
|
| 37 |
+
variance: [0.229, 0.224, 0.225]
|
| 38 |
+
enlarge_region : [0.5, 0.5] # x axis, y axis
|
| 39 |
+
enlarge_affinity: [0.5, 0.5]
|
| 40 |
+
gauss_init_size: 200
|
| 41 |
+
gauss_sigma: 40
|
| 42 |
+
watershed:
|
| 43 |
+
version: "skimage"
|
| 44 |
+
sure_fg_th: 0.75
|
| 45 |
+
sure_bg_th: 0.05
|
| 46 |
+
syn_sample: -1
|
| 47 |
+
custom_sample: -1
|
| 48 |
+
syn_aug:
|
| 49 |
+
random_scale:
|
| 50 |
+
range: [1.0, 1.5, 2.0]
|
| 51 |
+
option: False
|
| 52 |
+
random_rotate:
|
| 53 |
+
max_angle: 20
|
| 54 |
+
option: False
|
| 55 |
+
random_crop:
|
| 56 |
+
version: "random_resize_crop_synth"
|
| 57 |
+
option: True
|
| 58 |
+
random_horizontal_flip:
|
| 59 |
+
option: False
|
| 60 |
+
random_colorjitter:
|
| 61 |
+
brightness: 0.2
|
| 62 |
+
contrast: 0.2
|
| 63 |
+
saturation: 0.2
|
| 64 |
+
hue: 0.2
|
| 65 |
+
option: True
|
| 66 |
+
custom_aug:
|
| 67 |
+
random_scale:
|
| 68 |
+
range: [ 1.0, 1.5, 2.0 ]
|
| 69 |
+
option: False
|
| 70 |
+
random_rotate:
|
| 71 |
+
max_angle: 20
|
| 72 |
+
option: True
|
| 73 |
+
random_crop:
|
| 74 |
+
version: "random_resize_crop"
|
| 75 |
+
scale: [0.03, 0.4]
|
| 76 |
+
ratio: [0.75, 1.33]
|
| 77 |
+
rnd_threshold: 1.0
|
| 78 |
+
option: True
|
| 79 |
+
random_horizontal_flip:
|
| 80 |
+
option: True
|
| 81 |
+
random_colorjitter:
|
| 82 |
+
brightness: 0.2
|
| 83 |
+
contrast: 0.2
|
| 84 |
+
saturation: 0.2
|
| 85 |
+
hue: 0.2
|
| 86 |
+
option: True
|
| 87 |
+
|
| 88 |
+
test:
|
| 89 |
+
trained_model : null
|
| 90 |
+
custom_data:
|
| 91 |
+
test_set_size: 500
|
| 92 |
+
test_data_dir: "./data_root_dir/"
|
| 93 |
+
text_threshold: 0.75
|
| 94 |
+
low_text: 0.5
|
| 95 |
+
link_threshold: 0.2
|
| 96 |
+
canvas_size: 2240
|
| 97 |
+
mag_ratio: 1.75
|
| 98 |
+
poly: False
|
| 99 |
+
cuda: True
|
| 100 |
+
vis_opt: False
|
trainer/craft/config/load_config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
from functools import reduce
|
| 4 |
+
|
| 5 |
+
CONFIG_PATH = os.path.dirname(__file__)
|
| 6 |
+
|
| 7 |
+
def load_yaml(config_name):
    """Load `<config_name>.yaml` from the config directory and return the parsed data."""
    config_path = os.path.join(CONFIG_PATH, config_name) + '.yaml'
    with open(config_path) as file:
        return yaml.safe_load(file)
|
| 13 |
+
|
| 14 |
+
class DotDict(dict):
    """Dict subclass with attribute access and dotted-path lookup.

    ``d.a`` is equivalent to ``d["a"]``, and keys containing dots (or
    list/tuple keys) are resolved as nested lookups, e.g.
    ``d["a.b"] == d["a"]["b"]``.
    """

    def __getattr__(self, k):
        # Fall back to item lookup so config values read like attributes.
        try:
            v = self[k]
        except KeyError:
            # BUG FIX: the original bare `except:` fell through to
            # `super().__getattr__(k)`, which dict does not define; raise
            # the missing attribute explicitly instead.
            raise AttributeError(k) from None
        if isinstance(v, dict):
            # Wrap nested sections so chained attribute access keeps working.
            return DotDict(v)
        return v

    def __getitem__(self, k):
        if isinstance(k, str) and '.' in k:
            k = k.split('.')
        if isinstance(k, (list, tuple)):
            # Walk the dotted path one component at a time.
            return reduce(lambda d, kk: d[kk], k, self)
        return super().__getitem__(k)

    def get(self, k, default=None):
        """Like dict.get, but dotted string keys resolve nested values."""
        if isinstance(k, str) and '.' in k:
            try:
                return self[k]
            except KeyError:
                return default
        # BUG FIX: dict.get() accepts positional arguments only;
        # `super().get(k, default=default)` raised
        # `TypeError: get() takes no keyword arguments`.
        return super().get(k, default)
|
trainer/craft/config/syn_train.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_opt: False
|
| 2 |
+
|
| 3 |
+
results_dir: "./exp/"
|
| 4 |
+
vis_test_dir: "./vis_result/"
|
| 5 |
+
data_dir:
|
| 6 |
+
synthtext: "/data/SynthText/"
|
| 7 |
+
synthtext_gt: NULL
|
| 8 |
+
|
| 9 |
+
train:
|
| 10 |
+
backbone : vgg
|
| 11 |
+
dataset: ["synthtext"]
|
| 12 |
+
ckpt_path: null
|
| 13 |
+
eval_interval: 1000
|
| 14 |
+
batch_size: 5
|
| 15 |
+
st_iter: 0
|
| 16 |
+
end_iter: 50000
|
| 17 |
+
lr: 0.0001
|
| 18 |
+
lr_decay: 15000
|
| 19 |
+
gamma: 0.2
|
| 20 |
+
weight_decay: 0.00001
|
| 21 |
+
num_workers: 4
|
| 22 |
+
amp: True
|
| 23 |
+
loss: 3
|
| 24 |
+
neg_rto: 1
|
| 25 |
+
n_min_neg: 1000
|
| 26 |
+
data:
|
| 27 |
+
vis_opt: False
|
| 28 |
+
output_size: 768
|
| 29 |
+
mean: [0.485, 0.456, 0.406]
|
| 30 |
+
variance: [0.229, 0.224, 0.225]
|
| 31 |
+
enlarge_region : [0.5, 0.5] # x axis, y axis
|
| 32 |
+
enlarge_affinity: [0.5, 0.5]
|
| 33 |
+
gauss_init_size: 200
|
| 34 |
+
gauss_sigma: 40
|
| 35 |
+
syn_sample : -1
|
| 36 |
+
syn_aug:
|
| 37 |
+
random_scale:
|
| 38 |
+
range: [1.0, 1.5, 2.0]
|
| 39 |
+
option: False
|
| 40 |
+
random_rotate:
|
| 41 |
+
max_angle: 20
|
| 42 |
+
option: False
|
| 43 |
+
random_crop:
|
| 44 |
+
version: "random_resize_crop_synth"
|
| 45 |
+
rnd_threshold : 1.0
|
| 46 |
+
option: True
|
| 47 |
+
random_horizontal_flip:
|
| 48 |
+
option: False
|
| 49 |
+
random_colorjitter:
|
| 50 |
+
brightness: 0.2
|
| 51 |
+
contrast: 0.2
|
| 52 |
+
saturation: 0.2
|
| 53 |
+
hue: 0.2
|
| 54 |
+
option: True
|
| 55 |
+
|
| 56 |
+
test:
|
| 57 |
+
trained_model: null
|
| 58 |
+
icdar2013:
|
| 59 |
+
test_set_size: 233
|
| 60 |
+
cuda: True
|
| 61 |
+
vis_opt: True
|
| 62 |
+
test_data_dir : "/data/ICDAR2013/"
|
| 63 |
+
text_threshold: 0.85
|
| 64 |
+
low_text: 0.5
|
| 65 |
+
link_threshold: 0.2
|
| 66 |
+
canvas_size: 960
|
| 67 |
+
mag_ratio: 1.5
|
| 68 |
+
poly: False
|
trainer/craft/data/boxEnlarge.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def pointAngle(Apoint, Bpoint):
    """Return the slope of segment A->B; a tiny epsilon in the denominator
    guards against division by zero for vertical segments."""
    dy = Bpoint[1] - Apoint[1]
    dx = Bpoint[0] - Apoint[0]
    return dy / (dx + 10e-8)
|
| 8 |
+
|
| 9 |
+
def pointDistance(Apoint, Bpoint):
    """Euclidean distance between two (x, y) points."""
    dy = Bpoint[1] - Apoint[1]
    dx = Bpoint[0] - Apoint[0]
    return math.sqrt(dy ** 2 + dx ** 2)
|
| 11 |
+
|
| 12 |
+
def lineBiasAndK(Apoint, Bpoint):
    """Return (slope K, intercept B) of the line through points A and B."""
    slope = pointAngle(Apoint, Bpoint)
    intercept = Apoint[1] - slope * Apoint[0]
    return slope, intercept
|
| 17 |
+
|
| 18 |
+
def getX(K, B, Ypoint):
    """Invert y = K*x + B: return the x-coordinate at the given y, truncated to int."""
    x = (Ypoint - B) / K
    return int(x)
|
| 20 |
+
|
| 21 |
+
def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):
    """Push one corner of a box outward along the A->B direction.

    Args:
        Apoint, Bpoint: the two (x, y) points spanning the side.
        h, w: image height/width used to clamp the result.
        placehold: which corner is computed ('leftTop', 'rightTop',
            'rightBottom', 'leftBottom').
        enlarge_size: (x_ratio, y_ratio) fraction of the side length to grow.

    Returns:
        (int, int): the clamped, enlarged corner coordinates.
    """
    # NOTE: the original also computed `K, B = lineBiasAndK(Apoint, Bpoint)`
    # here; the result was never used, so the dead call is removed.
    angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
    distance = pointDistance(Apoint, Bpoint)

    x_enlarge_size, y_enlarge_size = enlarge_size

    # Decompose the growth along the segment direction into x/y components.
    XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
    YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)

    if placehold == 'leftTop':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightTop':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
    elif placehold == 'rightBottom':
        x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
        y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
    elif placehold == 'leftBottom':
        x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
        y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
    # NOTE(review): an unrecognized `placehold` raises UnboundLocalError
    # below (same as the original) — callers presumably only pass the four
    # corner names; confirm.
    return int(x1), int(y1)
|
| 45 |
+
|
| 46 |
+
def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):
    """Enlarge a 4-point box about the intersection of its diagonals.

    Args:
        box: (4, 2) array of corner points.
        h, w: image height/width used to clamp the enlarged corners.
        enlarge_size: (x_ratio, y_ratio) growth factors.
        horizontal_text_bool: if False the text is vertical, so the two
            growth ratios are swapped.

    Returns:
        np.ndarray: the enlarged (4, 2) integer corner array.
    """
    if not horizontal_text_bool:
        enlarge_size = (enlarge_size[1], enlarge_size[0])

    # Rotate the corner order so the point with the smallest x+y (closest to
    # the origin) comes first: leftTop, rightTop, rightBottom, leftBottom.
    box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0)

    Apoint, Bpoint, Cpoint, Dpoint = box
    # Center = intersection of the two diagonals (A-C and D-B).
    # NOTE(review): K1 == K2 (parallel diagonals, i.e. a degenerate box)
    # would raise ZeroDivisionError here — confirm inputs are proper
    # quadrilaterals.
    K1, B1 = lineBiasAndK(box[0], box[2])
    K2, B2 = lineBiasAndK(box[3], box[1])
    X = (B2 - B1)/(K1 - K2)
    Y = K1 * X + B1
    center = [X, Y]

    # Push each corner outward along its diagonal through the center.
    x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size)
    x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size)
    x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size)
    x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size)
    newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
    return newcharbox
|
trainer/craft/data/dataset.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import itertools
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import scipy.io as scio
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import cv2
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
|
| 13 |
+
from data import imgproc
|
| 14 |
+
from data.gaussian import GaussianBuilder
|
| 15 |
+
from data.imgaug import (
|
| 16 |
+
rescale,
|
| 17 |
+
random_resize_crop_synth,
|
| 18 |
+
random_resize_crop,
|
| 19 |
+
random_horizontal_flip,
|
| 20 |
+
random_rotate,
|
| 21 |
+
random_scale,
|
| 22 |
+
random_crop,
|
| 23 |
+
)
|
| 24 |
+
from data.pseudo_label.make_charbox import PseudoCharBoxBuilder
|
| 25 |
+
from utils.util import saveInput, saveImage
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CraftBaseDataset(Dataset):
|
| 29 |
+
    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        """Shared setup for CRAFT datasets.

        Args:
            output_size: side length of the square training crop.
            data_dir: root directory of the image/annotation data.
            saved_gt_dir: directory of precomputed GT score maps, or None
                to generate them on the fly in __getitem__.
            mean, variance: per-channel normalization statistics.
            gauss_init_size, gauss_sigma: Gaussian heatmap parameters.
            enlarge_region, enlarge_affinity: region/affinity box growth.
            aug: augmentation configuration (dot-accessible options).
            vis_test_dir: directory where visualization images are written.
            vis_opt: whether to dump visualization images.
            sample: number of images to randomly subsample, or -1 for all.
        """
        self.output_size = output_size
        self.data_dir = data_dir
        self.saved_gt_dir = saved_gt_dir
        self.mean, self.variance = mean, variance
        self.gaussian_builder = GaussianBuilder(
            gauss_init_size, gauss_sigma, enlarge_region, enlarge_affinity
        )
        self.aug = aug
        self.vis_test_dir = vis_test_dir
        self.vis_opt = vis_opt
        self.sample = sample
        if self.sample != -1:
            # Fixed seed so the subsample is reproducible across runs.
            # NOTE(review): relies on the subclass having set self.img_names
            # before calling this base __init__ — confirm in subclasses.
            random.seed(0)
            self.idx = random.sample(range(0, len(self.img_names)), self.sample)

        # Crop-region proposals used by random_resize_crop (filled externally).
        self.pre_crop_area = []
|
| 61 |
+
|
| 62 |
+
def augment_image(
|
| 63 |
+
self, image, region_score, affinity_score, confidence_mask, word_level_char_bbox
|
| 64 |
+
):
|
| 65 |
+
augment_targets = [image, region_score, affinity_score, confidence_mask]
|
| 66 |
+
|
| 67 |
+
if self.aug.random_scale.option:
|
| 68 |
+
augment_targets, word_level_char_bbox = random_scale(
|
| 69 |
+
augment_targets, word_level_char_bbox, self.aug.random_scale.range
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if self.aug.random_rotate.option:
|
| 73 |
+
augment_targets = random_rotate(
|
| 74 |
+
augment_targets, self.aug.random_rotate.max_angle
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if self.aug.random_crop.option:
|
| 78 |
+
if self.aug.random_crop.version == "random_crop_with_bbox":
|
| 79 |
+
augment_targets = random_crop_with_bbox(
|
| 80 |
+
augment_targets, word_level_char_bbox, self.output_size
|
| 81 |
+
)
|
| 82 |
+
elif self.aug.random_crop.version == "random_resize_crop_synth":
|
| 83 |
+
augment_targets = random_resize_crop_synth(
|
| 84 |
+
augment_targets, self.output_size
|
| 85 |
+
)
|
| 86 |
+
elif self.aug.random_crop.version == "random_resize_crop":
|
| 87 |
+
|
| 88 |
+
if len(self.pre_crop_area) > 0:
|
| 89 |
+
pre_crop_area = self.pre_crop_area
|
| 90 |
+
else:
|
| 91 |
+
pre_crop_area = None
|
| 92 |
+
|
| 93 |
+
augment_targets = random_resize_crop(
|
| 94 |
+
augment_targets,
|
| 95 |
+
self.aug.random_crop.scale,
|
| 96 |
+
self.aug.random_crop.ratio,
|
| 97 |
+
self.output_size,
|
| 98 |
+
self.aug.random_crop.rnd_threshold,
|
| 99 |
+
pre_crop_area,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
elif self.aug.random_crop.version == "random_crop":
|
| 103 |
+
augment_targets = random_crop(augment_targets, self.output_size,)
|
| 104 |
+
|
| 105 |
+
else:
|
| 106 |
+
assert "Undefined RandomCrop version"
|
| 107 |
+
|
| 108 |
+
if self.aug.random_horizontal_flip.option:
|
| 109 |
+
augment_targets = random_horizontal_flip(augment_targets)
|
| 110 |
+
|
| 111 |
+
if self.aug.random_colorjitter.option:
|
| 112 |
+
image, region_score, affinity_score, confidence_mask = augment_targets
|
| 113 |
+
image = Image.fromarray(image)
|
| 114 |
+
image = transforms.ColorJitter(
|
| 115 |
+
brightness=self.aug.random_colorjitter.brightness,
|
| 116 |
+
contrast=self.aug.random_colorjitter.contrast,
|
| 117 |
+
saturation=self.aug.random_colorjitter.saturation,
|
| 118 |
+
hue=self.aug.random_colorjitter.hue,
|
| 119 |
+
)(image)
|
| 120 |
+
else:
|
| 121 |
+
image, region_score, affinity_score, confidence_mask = augment_targets
|
| 122 |
+
|
| 123 |
+
return np.array(image), region_score, affinity_score, confidence_mask
|
| 124 |
+
|
| 125 |
+
def resize_to_half(self, ground_truth, interpolation):
|
| 126 |
+
return cv2.resize(
|
| 127 |
+
ground_truth,
|
| 128 |
+
(self.output_size // 2, self.output_size // 2),
|
| 129 |
+
interpolation=interpolation,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def __len__(self):
|
| 133 |
+
if self.sample != -1:
|
| 134 |
+
return len(self.idx)
|
| 135 |
+
else:
|
| 136 |
+
return len(self.img_names)
|
| 137 |
+
|
| 138 |
+
def __getitem__(self, index):
|
| 139 |
+
if self.sample != -1:
|
| 140 |
+
index = self.idx[index]
|
| 141 |
+
if self.saved_gt_dir is None:
|
| 142 |
+
(
|
| 143 |
+
image,
|
| 144 |
+
region_score,
|
| 145 |
+
affinity_score,
|
| 146 |
+
confidence_mask,
|
| 147 |
+
word_level_char_bbox,
|
| 148 |
+
all_affinity_bbox,
|
| 149 |
+
words,
|
| 150 |
+
) = self.make_gt_score(index)
|
| 151 |
+
else:
|
| 152 |
+
(
|
| 153 |
+
image,
|
| 154 |
+
region_score,
|
| 155 |
+
affinity_score,
|
| 156 |
+
confidence_mask,
|
| 157 |
+
word_level_char_bbox,
|
| 158 |
+
words,
|
| 159 |
+
) = self.load_saved_gt_score(index)
|
| 160 |
+
all_affinity_bbox = []
|
| 161 |
+
|
| 162 |
+
if self.vis_opt:
|
| 163 |
+
saveImage(
|
| 164 |
+
self.img_names[index],
|
| 165 |
+
self.vis_test_dir,
|
| 166 |
+
image.copy(),
|
| 167 |
+
word_level_char_bbox.copy(),
|
| 168 |
+
all_affinity_bbox.copy(),
|
| 169 |
+
region_score.copy(),
|
| 170 |
+
affinity_score.copy(),
|
| 171 |
+
confidence_mask.copy(),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
image, region_score, affinity_score, confidence_mask = self.augment_image(
|
| 175 |
+
image, region_score, affinity_score, confidence_mask, word_level_char_bbox
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if self.vis_opt:
|
| 179 |
+
saveInput(
|
| 180 |
+
self.img_names[index],
|
| 181 |
+
self.vis_test_dir,
|
| 182 |
+
image,
|
| 183 |
+
region_score,
|
| 184 |
+
affinity_score,
|
| 185 |
+
confidence_mask,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
region_score = self.resize_to_half(region_score, interpolation=cv2.INTER_CUBIC)
|
| 189 |
+
affinity_score = self.resize_to_half(
|
| 190 |
+
affinity_score, interpolation=cv2.INTER_CUBIC
|
| 191 |
+
)
|
| 192 |
+
confidence_mask = self.resize_to_half(
|
| 193 |
+
confidence_mask, interpolation=cv2.INTER_NEAREST
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
image = imgproc.normalizeMeanVariance(
|
| 197 |
+
np.array(image), mean=self.mean, variance=self.variance
|
| 198 |
+
)
|
| 199 |
+
image = image.transpose(2, 0, 1)
|
| 200 |
+
|
| 201 |
+
return image, region_score, affinity_score, confidence_mask
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class SynthTextDataSet(CraftBaseDataset):
    """SynthText dataset wrapper.

    SynthText ships character-level boxes in ``gt.mat``, so region/affinity
    score maps can be generated directly (no pseudo-labeling needed).
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
    ):
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        # img_names / char_bbox / img_words are parallel arrays indexed by sample.
        self.img_names, self.char_bbox, self.img_words = self.load_data()
        # Only the first 1000 samples are eligible for visualization dumps.
        self.vis_index = list(range(1000))

    def load_data(self, bbox="char"):
        """Load image names, bounding boxes and raw text from ``gt.mat``.

        :param bbox: "char" loads character-level boxes (training);
            any other value loads word-level boxes (used for test).
        :return: (img_names, img_bbox, img_words) parallel arrays.
        """

        gt = scio.loadmat(os.path.join(self.data_dir, "gt.mat"))
        img_names = gt["imnames"][0]
        img_words = gt["txt"][0]

        if bbox == "char":
            img_bbox = gt["charBB"][0]
        else:
            img_bbox = gt["wordBB"][0]  # word bbox needed for test

        return img_names, img_bbox, img_words

    def dilate_img_to_output_size(self, image, char_bbox):
        """Upscale *image* (and its boxes, in place) so its short side reaches
        ``self.output_size``; images already large enough are left unscaled."""
        h, w, _ = image.shape
        if min(h, w) <= self.output_size:
            scale = float(self.output_size) / min(h, w)
        else:
            scale = 1.0
        image = cv2.resize(
            image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
        )
        char_bbox *= scale
        return image, char_bbox

    def make_gt_score(self, index):
        """Build GT score maps for sample *index* from SynthText's char boxes.

        The flat character-box array is regrouped into per-word chunks by
        word length, then fed to the Gaussian builder.

        :return: (image, region_score, affinity_score, confidence_mask,
            word_level_char_bbox, all_affinity_bbox, words)
        """
        img_path = os.path.join(self.data_dir, self.img_names[index][0])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        all_char_bbox = self.char_bbox[index].transpose(
            (2, 1, 0)
        )  # shape : (Number of characters in image, 4, 2)

        img_h, img_w, _ = image.shape

        # SynthText boxes are all trusted ground truth -> full confidence.
        confidence_mask = np.ones((img_h, img_w), dtype=np.float32)

        # gt["txt"] stores free text per image; split on spaces/newlines to
        # recover individual words, dropping empty fragments.
        words = [
            re.split(" \n|\n |\n| ", word.strip()) for word in self.img_words[index]
        ]
        words = list(itertools.chain(*words))
        words = [word for word in words if len(word) > 0]

        word_level_char_bbox = []
        char_idx = 0

        # Consume the flat char-box array word by word; each word owns
        # exactly len(word) consecutive character boxes.
        for i in range(len(words)):
            length_of_word = len(words[i])
            word_bbox = all_char_bbox[char_idx : char_idx + length_of_word]
            assert len(word_bbox) == length_of_word
            char_idx += length_of_word
            word_bbox = np.array(word_bbox)
            word_level_char_bbox.append(word_bbox)

        # SynthText text is treated as horizontal for enlargement purposes.
        region_score = self.gaussian_builder.generate_region(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True for _ in range(len(words))],
        )
        affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
            img_h,
            img_w,
            word_level_char_bbox,
            horizontal_text_bools=[True for _ in range(len(words))],
        )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class CustomDataset(CraftBaseDataset):
    """ICDAR-style dataset with word-level GT only.

    Character boxes are not provided, so they are pseudo-labeled per word
    with the current model (weak supervision), or pre-saved score maps from
    the official CRAFT model are loaded instead.
    """

    def __init__(
        self,
        output_size,
        data_dir,
        saved_gt_dir,
        mean,
        variance,
        gauss_init_size,
        gauss_sigma,
        enlarge_region,
        enlarge_affinity,
        aug,
        vis_test_dir,
        vis_opt,
        sample,
        watershed_param,
        pseudo_vis_opt,
        do_not_care_label,
    ):
        super().__init__(
            output_size,
            data_dir,
            saved_gt_dir,
            mean,
            variance,
            gauss_init_size,
            gauss_sigma,
            enlarge_region,
            enlarge_affinity,
            aug,
            vis_test_dir,
            vis_opt,
            sample,
        )
        self.pseudo_vis_opt = pseudo_vis_opt
        # Transcriptions in this list (e.g. "###") mark unreadable words.
        self.do_not_care_label = do_not_care_label
        self.pseudo_charbox_builder = PseudoCharBoxBuilder(
            watershed_param, vis_test_dir, pseudo_vis_opt, self.gaussian_builder
        )
        # Only the first 1000 samples are eligible for visualization dumps.
        self.vis_index = list(range(1000))
        self.img_dir = os.path.join(data_dir, "ch4_training_images")
        self.img_gt_box_dir = os.path.join(
            data_dir, "ch4_training_localization_transcription_gt"
        )
        self.img_names = os.listdir(self.img_dir)

    def update_model(self, net):
        """Swap in the latest network used for pseudo character labeling."""
        self.net = net

    def update_device(self, gpu):
        """Set the GPU id used when running the pseudo-label forward pass."""
        self.gpu = gpu

    def load_img_gt_box(self, img_gt_box_path):
        """Parse an ICDAR GT file: 8 comma-separated coords then transcription.

        :return: (word_bboxes ndarray of shape (N, 4, 2), list of N words);
            don't-care entries keep their box but get the canonical
            don't-care transcription.
        """
        # Use a context manager so the file handle is always closed
        # (the original left it open, leaking a descriptor per call).
        with open(img_gt_box_path, encoding="utf-8") as f:
            lines = f.readlines()
        word_bboxes = []
        words = []
        for line in lines:
            # Strip a possible UTF-8 BOM before splitting.
            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box_points = [int(box_info[i]) for i in range(8)]
            box_points = np.array(box_points, np.float32).reshape(4, 2)
            # The transcription itself may contain commas; re-join the rest.
            word = ",".join(box_info[8:])
            if word in self.do_not_care_label:
                # Normalize every don't-care transcription to the first label.
                words.append(self.do_not_care_label[0])
                word_bboxes.append(box_points)
                continue
            word_bboxes.append(box_points)
            words.append(word)
        return np.array(word_bboxes), words

    def load_data(self, index):
        """Load image *index* and pseudo-label character boxes per word.

        Don't-care words are zeroed out in the confidence mask and skipped.

        :return: (image, word_level_char_bbox, do_care_words,
            confidence_mask, horizontal_text_bools)
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(
            img_gt_box_path
        )  # shape : (Number of word bbox, 4, 2)
        confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)

        word_level_char_bbox = []
        do_care_words = []
        horizontal_text_bools = []

        if len(word_bboxes) == 0:
            return (
                image,
                word_level_char_bbox,
                do_care_words,
                confidence_mask,
                horizontal_text_bools,
            )
        # Keep an unmodified copy for filling the confidence mask.
        _word_bboxes = word_bboxes.copy()
        for i in range(len(word_bboxes)):
            if words[i] in self.do_not_care_label:
                # Exclude unreadable words from the loss entirely.
                cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], 0)
                continue

            # Split the word box into per-character boxes with the current
            # model's region prediction (watershed); confidence reflects how
            # well the split matches the transcription length.
            (
                pseudo_char_bbox,
                confidence,
                horizontal_text_bool,
            ) = self.pseudo_charbox_builder.build_char_box(
                self.net, self.gpu, image, word_bboxes[i], words[i], img_name=img_name
            )

            cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], confidence)
            do_care_words.append(words[i])
            word_level_char_bbox.append(pseudo_char_bbox)
            horizontal_text_bools.append(horizontal_text_bool)

        return (
            image,
            word_level_char_bbox,
            do_care_words,
            confidence_mask,
            horizontal_text_bools,
        )

    def make_gt_score(self, index):
        """
        Make region, affinity scores using pseudo character-level GT bounding box
        word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
        :rtype region_score: np.float32
        :rtype affinity_score: np.float32
        :rtype confidence_mask: np.float32
        :rtype word_level_char_bbox: np.float32
        :rtype words: list
        """
        (
            image,
            word_level_char_bbox,
            words,
            confidence_mask,
            horizontal_text_bools,
        ) = self.load_data(index)
        img_h, img_w, _ = image.shape

        if len(word_level_char_bbox) == 0:
            # No usable words: supervise with all-zero score maps.
            region_score = np.zeros((img_h, img_w), dtype=np.float32)
            affinity_score = np.zeros((img_h, img_w), dtype=np.float32)
            all_affinity_bbox = []
        else:
            region_score = self.gaussian_builder.generate_region(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )
            affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
                img_h, img_w, word_level_char_bbox, horizontal_text_bools
            )

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            all_affinity_bbox,
            words,
        )

    def load_saved_gt_score(self, index):
        """
        Load pre-saved official CRAFT model's region, affinity scores to train
        word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
        :rtype region_score: np.float32
        :rtype affinity_score: np.float32
        :rtype confidence_mask: np.float32
        :rtype word_level_char_bbox: np.float32
        :rtype words: list
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        img_gt_box_path = os.path.join(
            self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
        )
        word_bboxes, words = self.load_img_gt_box(img_gt_box_path)
        # The saved scores were produced on rescaled images; match that scale.
        image, word_bboxes = rescale(image, word_bboxes)
        img_h, img_w, _ = image.shape

        # e.g. "img_123.jpg" -> 123, matching the saved-score filenames.
        query_idx = int(self.img_names[index].split(".")[0].split("_")[1])

        saved_region_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_region.jpg"
        )
        saved_affi_scores_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_affi.jpg"
        )
        saved_cf_mask_path = os.path.join(
            self.saved_gt_dir, f"res_img_{query_idx}_cf_mask_thresh_0.6.jpg"
        )
        region_score = cv2.imread(saved_region_scores_path, cv2.IMREAD_GRAYSCALE)
        affinity_score = cv2.imread(saved_affi_scores_path, cv2.IMREAD_GRAYSCALE)
        confidence_mask = cv2.imread(saved_cf_mask_path, cv2.IMREAD_GRAYSCALE)

        # Resize saved maps to the (rescaled) image resolution; the mask uses
        # nearest-neighbour so its values are not blended.
        region_score = cv2.resize(region_score, (img_w, img_h))
        affinity_score = cv2.resize(affinity_score, (img_w, img_h))
        confidence_mask = cv2.resize(
            confidence_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST
        )

        # JPEG grayscale (0-255) -> float scores in [0, 1].
        region_score = region_score.astype(np.float32) / 255
        affinity_score = affinity_score.astype(np.float32) / 255
        confidence_mask = confidence_mask.astype(np.float32) / 255

        # NOTE : Even though word_level_char_bbox is not necessary, align bbox format with make_gt_score()
        word_level_char_bbox = []

        for i in range(len(word_bboxes)):
            word_level_char_bbox.append(np.expand_dims(word_bboxes[i], 0))

        return (
            image,
            region_score,
            affinity_score,
            confidence_mask,
            word_level_char_bbox,
            words,
        )
|
trainer/craft/data/gaussian.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
from data.boxEnlarge import enlargebox
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class GaussianBuilder(object):
    """Builds CRAFT region/affinity score maps.

    A single isotropic 2D Gaussian (masked to a circle) is precomputed once,
    then perspective-warped onto each character / affinity box.
    """

    def __init__(self, init_size, sigma, enlarge_region, enlarge_affinity):
        # Side length (pixels) of the square base Gaussian map.
        self.init_size = init_size
        # Standard deviation of the base Gaussian.
        self.sigma = sigma
        # Per-axis enlargement factors applied to region / affinity boxes.
        self.enlarge_region = enlarge_region
        self.enlarge_affinity = enlarge_affinity
        self.gaussian_map, self.gaussian_map_color = self.generate_gaussian_map()

    def generate_gaussian_map(self):
        """Precompute the base Gaussian map.

        :return: (float32 map normalized to peak 1.0, JET-colored uint8
            visualization of the same map)
        """
        circle_mask = self.generate_circle_mask()

        gaussian_map = np.zeros((self.init_size, self.init_size), np.float32)

        # Evaluate the isotropic 2D Gaussian pdf centered at the map center.
        for i in range(self.init_size):
            for j in range(self.init_size):
                gaussian_map[i, j] = (
                    1
                    / 2
                    / np.pi
                    / (self.sigma ** 2)
                    * np.exp(
                        -1
                        / 2
                        * (
                            (i - self.init_size / 2) ** 2 / (self.sigma ** 2)
                            + (j - self.init_size / 2) ** 2 / (self.sigma ** 2)
                        )
                    )
                )

        # Zero everything outside the inscribed circle, then rescale so the
        # peak value is exactly 1.0.
        gaussian_map = gaussian_map * circle_mask
        gaussian_map = (gaussian_map / np.max(gaussian_map)).astype(np.float32)

        gaussian_map_color = (gaussian_map * 255).astype(np.uint8)
        gaussian_map_color = cv2.applyColorMap(gaussian_map_color, cv2.COLORMAP_JET)
        return gaussian_map, gaussian_map_color

    def generate_circle_mask(self):
        """Binary mask (float32, 1 inside / 0 outside) of the largest circle
        inscribed in the base map."""

        zero_arr = np.zeros((self.init_size, self.init_size), np.float32)
        circle_mask = cv2.circle(
            img=zero_arr,
            center=(self.init_size // 2, self.init_size // 2),
            radius=self.init_size // 2,
            color=1,
            thickness=-1,
        )

        return circle_mask

    def four_point_transform(self, bbox):
        """
        Using the bbox, standard 2D gaussian map, returns Transformed 2d Gaussian map

        The base square map is perspective-warped into the quadrilateral
        *bbox* (expected already shifted so its min corner is near 0).
        :return: (warped map, width, height) where width/height come from the
            max x / max y of the bbox.
        """
        width, height = (
            np.max(bbox[:, 0]).astype(np.int32),
            np.max(bbox[:, 1]).astype(np.int32),
        )
        # Corners of the base Gaussian map, in tl/tr/br/bl order.
        init_points = np.array(
            [
                [0, 0],
                [self.init_size, 0],
                [self.init_size, self.init_size],
                [0, self.init_size],
            ],
            dtype="float32",
        )

        M = cv2.getPerspectiveTransform(init_points, bbox)
        warped_gaussian_map = cv2.warpPerspective(self.gaussian_map, M, (width, height))
        return warped_gaussian_map, width, height

    def add_gaussian_map_to_score_map(
        self, score_map, bbox, enlarge_size, horizontal_text_bool, map_type=None
    ):
        """
        Mapping 2D Gaussian to the character box coordinates of the score_map.

        :param score_map: Target map to put 2D gaussian on character box
        :type score_map: np.float32
        :param bbox: character boxes
        :type bbox: np.float32
        :param enlarge_size: Enlarge size of gaussian map to fit character shape
        :type enlarge_size: list of enlarge size [x dim, y dim]
        :param horizontal_text_bool: Flag that bbox is horizontal text or not
        :type horizontal_text_bool: bool
        :param map_type: Whether map's type is "region" | "affinity"
        :type map_type: str
        :return score_map: score map that all 2D gaussian put on character box
        :rtype: np.float32
        """

        map_h, map_w = score_map.shape
        bbox = enlargebox(bbox, map_h, map_w, enlarge_size, horizontal_text_bool)

        # If any one point of character bbox is out of range, don't put in on map
        if np.any(bbox < 0) or np.any(bbox[:, 0] > map_w) or np.any(bbox[:, 1] > map_h):
            return score_map

        # Shift the box so its top-left corner is the local origin; the warp
        # is done in this local frame, then pasted back at (bbox_top, bbox_left).
        bbox_left, bbox_top = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(
            np.int32
        )
        bbox -= (bbox_left, bbox_top)
        warped_gaussian_map, width, height = self.four_point_transform(
            bbox.astype(np.float32)
        )

        try:
            # Element-wise max blend: keep the stronger score where Gaussians
            # from neighbouring characters overlap.
            bbox_area_of_image = score_map[
                bbox_top : bbox_top + height, bbox_left : bbox_left + width,
            ]
            high_value_score = np.where(
                warped_gaussian_map > bbox_area_of_image,
                warped_gaussian_map,
                bbox_area_of_image,
            )
            score_map[
                bbox_top : bbox_top + height, bbox_left : bbox_left + width,
            ] = high_value_score

        except Exception as e:
            # Degenerate boxes (zero/negative extent) can make the slice and
            # warped map shapes disagree; skip the box but keep training.
            print("Error : {}".format(e))
            print(
                "On generating {} map, strange box came out. (width: {}, height: {})".format(
                    map_type, width, height
                )
            )

        return score_map

    def calculate_affinity_box_points(self, bbox_1, bbox_2, vertical=False):
        """Affinity quad between two adjacent character boxes.

        Each corner is the centroid of a triangle formed by two box corners
        and the box center (the standard CRAFT affinity construction).
        Boxes are in tl/tr/br/bl order; *vertical* selects which edges pair up.
        """
        center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
        if vertical:
            tl = (bbox_1[0] + bbox_1[-1] + center_1) / 3
            tr = (bbox_1[1:3].sum(0) + center_1) / 3
            br = (bbox_2[1:3].sum(0) + center_2) / 3
            bl = (bbox_2[0] + bbox_2[-1] + center_2) / 3
        else:
            tl = (bbox_1[0:2].sum(0) + center_1) / 3
            tr = (bbox_2[0:2].sum(0) + center_2) / 3
            br = (bbox_2[2:4].sum(0) + center_2) / 3
            bl = (bbox_1[2:4].sum(0) + center_1) / 3
        affinity_box = np.array([tl, tr, br, bl]).astype(np.float32)
        return affinity_box

    def generate_region(
        self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
    ):
        """Region score map: one warped Gaussian per character box."""
        region_map = np.zeros([img_h, img_w], dtype=np.float32)
        for i in range(
            len(word_level_char_bbox)
        ):  # shape : [word_num, [char_num_in_one_word, 4, 2]]
            for j in range(len(word_level_char_bbox[i])):
                region_map = self.add_gaussian_map_to_score_map(
                    region_map,
                    word_level_char_bbox[i][j].copy(),
                    self.enlarge_region,
                    horizontal_text_bools[i],
                    map_type="region",
                )
        return region_map

    def generate_affinity(
        self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
    ):
        """Affinity score map: one warped Gaussian per adjacent character pair
        within each word.

        :return: (affinity map, affinity boxes — an (N, 4, 2) ndarray when any
            pair exists, otherwise an empty list)
        """

        affinity_map = np.zeros([img_h, img_w], dtype=np.float32)
        all_affinity_bbox = []
        for i in range(len(word_level_char_bbox)):
            # len-1 pairs of consecutive characters per word.
            for j in range(len(word_level_char_bbox[i]) - 1):
                affinity_bbox = self.calculate_affinity_box_points(
                    word_level_char_bbox[i][j], word_level_char_bbox[i][j + 1]
                )

                affinity_map = self.add_gaussian_map_to_score_map(
                    affinity_map,
                    affinity_bbox.copy(),
                    self.enlarge_affinity,
                    horizontal_text_bools[i],
                    map_type="affinity",
                )
                all_affinity_bbox.append(np.expand_dims(affinity_bbox, axis=0))

        if len(all_affinity_bbox) > 0:
            all_affinity_bbox = np.concatenate(all_affinity_bbox, axis=0)
        return affinity_map, all_affinity_bbox
|
trainer/craft/data/imgaug.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torchvision.transforms.functional import resized_crop, crop
|
| 7 |
+
from torchvision.transforms import RandomResizedCrop, RandomCrop
|
| 8 |
+
from torchvision.transforms import InterpolationMode
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def rescale(img, bboxes, target_size=2240):
    """Resize *img* so its longer side equals *target_size*.

    Bounding boxes are scaled by the same factor and returned as a new array.
    """
    h, w = img.shape[0:2]
    scale_factor = target_size / max(h, w)
    resized = cv2.resize(
        img, dsize=None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC
    )
    return resized, bboxes * scale_factor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def random_resize_crop_synth(augment_targets, size):
    """Crop a random square (side = image short side) from all four targets
    and resize it to (size, size).

    Image and score maps are resampled bicubically; the confidence mask uses
    nearest-neighbour so its values stay un-blended.
    """
    pil_targets = [Image.fromarray(target) for target in augment_targets]
    image = pil_targets[0]

    # Same crop window for every target so they stay aligned.
    short_side = min(image.size)
    i, j, h, w = RandomCrop.get_params(image, output_size=(short_side, short_side))

    interpolations = [
        InterpolationMode.BICUBIC,   # image
        InterpolationMode.BICUBIC,   # region score
        InterpolationMode.BICUBIC,   # affinity score
        InterpolationMode.NEAREST,   # confidence mask
    ]
    return [
        np.array(
            resized_crop(target, i, j, h, w, (size, size), interpolation=interp)
        )
        for target, interp in zip(pil_targets, interpolations)
    ]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def random_resize_crop(
    augment_targets, scale, ratio, size, threshold, pre_crop_area=None
):
    """RandomResizedCrop applied identically to image and score maps.

    :param augment_targets: [image, region_score, affinity_score,
        confidence_mask] as numpy arrays
    :param scale: area scale range passed to RandomResizedCrop.get_params
    :param ratio: aspect-ratio range passed to RandomResizedCrop.get_params
    :param size: output side length (output is size x size)
    :param threshold: probability of a jittered crop; otherwise the full
        image is used (scale/ratio fixed to 1.0)
    :param pre_crop_area: optional precomputed (i, j, h, w) window that
        overrides random sampling
    :return: list [image, region_score, affinity_score, confidence_mask]
    """
    image, region_score, affinity_score, confidence_mask = augment_targets

    image = Image.fromarray(image)
    region_score = Image.fromarray(region_score)
    affinity_score = Image.fromarray(affinity_score)
    confidence_mask = Image.fromarray(confidence_mask)

    # Identity comparison with None (PEP 8) instead of `!= None`.
    if pre_crop_area is not None:
        i, j, h, w = pre_crop_area

    else:
        if random.random() < threshold:
            i, j, h, w = RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
        else:
            # No jitter: "crop" the whole image.
            i, j, h, w = RandomResizedCrop.get_params(
                image, scale=(1.0, 1.0), ratio=(1.0, 1.0)
            )

    # Same window for every target; mask uses nearest-neighbour so its
    # confidence values are not blended.
    image = resized_crop(
        image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC
    )
    region_score = resized_crop(
        region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC
    )
    affinity_score = resized_crop(
        affinity_score,
        i,
        j,
        h,
        w,
        (size, size),
        interpolation=InterpolationMode.BICUBIC,
    )
    confidence_mask = resized_crop(
        confidence_mask,
        i,
        j,
        h,
        w,
        (size, size),
        interpolation=InterpolationMode.NEAREST,
    )

    image = np.array(image)
    region_score = np.array(region_score)
    affinity_score = np.array(affinity_score)
    confidence_mask = np.array(confidence_mask)
    augment_targets = [image, region_score, affinity_score, confidence_mask]

    return augment_targets
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def random_crop(augment_targets, size):
    """Cut the same random (size x size) window out of all four targets."""
    pil_targets = [Image.fromarray(target) for target in augment_targets]

    # One window, shared by every target, keeps the maps aligned.
    i, j, h, w = RandomCrop.get_params(pil_targets[0], output_size=(size, size))

    return [np.array(crop(target, i, j, h, w)) for target in pil_targets]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def random_horizontal_flip(imgs):
    """With probability 0.5, flip every array in *imgs* left-right.

    The list is mutated in place and also returned.
    """
    if random.random() < 0.5:
        imgs[:] = [np.flip(img, axis=1).copy() for img in imgs]
    return imgs
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def random_scale(images, word_level_char_bbox, scale_range):
    """Resize every image by one factor drawn from *scale_range*.

    Character boxes are scaled in place to stay consistent with the images.
    Returns the (mutated) image list.
    """
    chosen_scale = random.sample(scale_range, 1)[0]

    images[:] = [
        cv2.resize(img, dsize=None, fx=chosen_scale, fy=chosen_scale) for img in images
    ]

    for bbox in word_level_char_bbox:
        bbox *= chosen_scale

    return images
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def random_rotate(images, max_angle):
    """Rotate all images by one shared random angle in [-max_angle, max_angle]
    degrees (about the image center, same output size).

    The last entry of *images* (the confidence mask, by convention of the
    caller) is rotated with nearest-neighbour interpolation so its values are
    not blended; all other maps use the default bilinear interpolation.
    """
    angle = random.random() * 2 * max_angle - max_angle
    for i in range(len(images)):
        img = images[i]
        # Fix: the original bound `w, h = img.shape[:2]`, i.e. `w` was the
        # row count (height) and `h` the column count (width). It worked only
        # because the swap was applied consistently; use honest names.
        height, width = img.shape[:2]
        # cv2 expects center as (x, y) and dsize as (width, height).
        rotation_matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1)
        if i == len(images) - 1:
            img_rotation = cv2.warpAffine(
                img, M=rotation_matrix, dsize=(width, height), flags=cv2.INTER_NEAREST
            )
        else:
            img_rotation = cv2.warpAffine(img, rotation_matrix, (width, height))
        images[i] = img_rotation
    return images
|
trainer/craft/data/imgproc.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2019-present NAVER Corp.
|
| 3 |
+
MIT License
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# -*- coding: utf-8 -*-
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
from skimage import io
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def loadImage(img_file):
    """Read an image file and return it as a 3-channel RGB uint8 array.

    Handles multi-frame containers (first frame kept), grayscale input
    (expanded to 3 channels) and RGBA input (alpha channel dropped).
    """
    img = io.imread(img_file)  # RGB order
    if img.shape[0] == 2:
        # Multi-frame image (e.g. animated GIF): keep the first frame only.
        img = img[0]
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4:
        img = img[:, :, :3]

    return np.array(img)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def normalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """Standardize an RGB image using ImageNet statistics scaled to [0, 255].

    Args:
        in_img: RGB image with pixel values in [0, 255]; not modified.
        mean: per-channel mean as a fraction of 255.
        variance: per-channel std-dev as a fraction of 255.

    Returns:
        A new float32 array: ``(in_img - mean*255) / (variance*255)``.
    """
    img = in_img.copy().astype(np.float32)
    img -= np.array([m * 255.0 for m in mean], dtype=np.float32)
    img /= np.array([v * 255.0 for v in variance], dtype=np.float32)
    return img
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def denormalizeMeanVariance(
    in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
):
    """Invert :func:`normalizeMeanVariance` back to a uint8 RGB image.

    Args:
        in_img: normalized float image (RGB order); not modified.
        mean: per-channel mean as a fraction of 255.
        variance: per-channel std-dev as a fraction of 255.

    Returns:
        uint8 array clipped to [0, 255].
    """
    # Work on a copy; the in-place ops below keep the input's float dtype.
    img = in_img.copy()
    img *= variance
    img += mean
    img *= 255.0
    return np.clip(img, 0, 255).astype(np.uint8)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
    """Resize keeping aspect ratio, then zero-pad to multiples of 32.

    The longer side is scaled by *mag_ratio* but capped at *square_size*.
    Returns (padded_image_float32, scale_ratio, valid_size_heatmap) where
    valid_size_heatmap is (target_h // 2, target_w // 2) — the network's
    score maps are half resolution.
    """
    height, width, channel = img.shape

    # magnified target for the longer side, capped at square_size
    target_size = min(mag_ratio * max(height, width), square_size)
    ratio = target_size / max(height, width)
    target_h, target_w = int(height * ratio), int(width * ratio)

    # NOTE: heat maps are produced at half the input resolution
    valid_size_heatmap = (int(target_h / 2), int(target_w / 2))

    proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # pad each dimension up to the next multiple of 32 (network stride)
    target_h32 = target_h if target_h % 32 == 0 else target_h + (32 - target_h % 32)
    target_w32 = target_w if target_w % 32 == 0 else target_w + (32 - target_w % 32)
    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
    resized[0:target_h, 0:target_w, :] = proc

    return resized, ratio, valid_size_heatmap
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def cvt2HeatmapImg(img):
    """Convert a score map in [0, 1] into a JET-colored uint8 heat map."""
    scaled = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    return cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
|
trainer/craft/data/pseudo_label/make_charbox.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from data import imgproc
|
| 10 |
+
from data.pseudo_label.watershed import exec_watershed_by_version
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PseudoCharBoxBuilder:
    """Build pseudo character-level boxes for weakly supervised CRAFT training.

    Each word-level box is cropped and perspective-rectified, the current
    model predicts a region score for the crop, and watershed segmentation
    splits that score into per-character boxes. When the split disagrees too
    much with the word's character count, the word is divided into
    equal-width boxes instead, with a reduced confidence of 0.5.
    """

    def __init__(self, watershed_param, vis_test_dir, pseudo_vis_opt, gaussian_builder):
        self.watershed_param = watershed_param
        self.vis_test_dir = vis_test_dir
        self.pseudo_vis_opt = pseudo_vis_opt
        self.gaussian_builder = gaussian_builder
        self.cnt = 0
        # True while the most recently cropped word was treated as vertical
        self.flag = False

    def crop_image_by_bbox(self, image, box, word):
        """Perspective-warp the quadrilateral *box* out of *image*.

        Returns (warped_crop, perspective_matrix M, horizontal_text_bool).
        Vertical words (detected by aspect-ratio heuristics) are mapped so
        the text always runs left-to-right in the crop.
        """
        w = max(
            int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[2] - box[3]))
        )
        h = max(
            int(np.linalg.norm(box[0] - box[3])), int(np.linalg.norm(box[1] - box[2]))
        )
        # BUG FIX: a degenerate (zero-width) box previously dropped into a
        # leftover ipdb.set_trace() debugging hook; treat it as maximally
        # vertical instead of hanging in a debugger.
        word_ratio = h / w if w > 0 else float("inf")

        one_char_ratio = min(h, w) / (max(h, w) / len(word))

        # NOTE: criterion to split vertical word in here is set to work properly on IC15 dataset
        if word_ratio > 2 or (word_ratio > 1.6 and one_char_ratio > 2.4):
            # warping method of vertical word (classified by upper condition)
            horizontal_text_bool = False
            long_side = h
            short_side = w
            M = cv2.getPerspectiveTransform(
                np.float32(box),
                np.float32(
                    np.array(
                        [
                            [long_side, 0],
                            [long_side, short_side],
                            [0, short_side],
                            [0, 0],
                        ]
                    )
                ),
            )
            self.flag = True
        else:
            # warping method of horizontal word
            horizontal_text_bool = True
            long_side = w
            short_side = h
            M = cv2.getPerspectiveTransform(
                np.float32(box),
                np.float32(
                    np.array(
                        [
                            [0, 0],
                            [long_side, 0],
                            [long_side, short_side],
                            [0, short_side],
                        ]
                    )
                ),
            )
            self.flag = False

        warped = cv2.warpPerspective(image, M, (long_side, short_side))
        return warped, M, horizontal_text_bool

    def inference_word_box(self, net, gpu, word_image):
        """Run the detector on one rectified word crop; return raw scores."""
        if net.training:
            net.eval()

        with torch.no_grad():
            word_img_torch = torch.from_numpy(
                imgproc.normalizeMeanVariance(
                    word_image,
                    mean=(0.485, 0.456, 0.406),
                    variance=(0.229, 0.224, 0.225),
                )
            )
            # HWC -> 1CHW for the network
            word_img_torch = word_img_torch.permute(2, 0, 1).unsqueeze(0)
            word_img_torch = word_img_torch.type(torch.FloatTensor).cuda(gpu)
            with torch.cuda.amp.autocast():
                word_img_scores, _ = net(word_img_torch)
        return word_img_scores

    def visualize_pseudo_label(
        self, word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
    ):
        """Write a side-by-side debug montage of the pseudo-labelling stages."""
        word_img_h, word_img_w, _ = word_image.shape
        word_img_cp1 = word_image.copy()
        word_img_cp2 = word_image.copy()
        _watershed_box = np.int32(watershed_box)
        _pseudo_char_bbox = np.int32(pseudo_char_bbox)

        region_score_color = cv2.applyColorMap(np.uint8(region_score), cv2.COLORMAP_JET)
        region_score_color = cv2.resize(region_score_color, (word_img_w, word_img_h))

        for box in _watershed_box:
            cv2.polylines(
                np.uint8(word_img_cp1),
                [np.reshape(box, (-1, 1, 2))],
                True,
                (255, 0, 0),
            )

        for box in _pseudo_char_bbox:
            cv2.polylines(
                np.uint8(word_img_cp2), [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0)
            )

        # NOTE: Just for visualize, put gaussian map on char box
        pseudo_gt_region_score = self.gaussian_builder.generate_region(
            word_img_h, word_img_w, [_pseudo_char_bbox], [True]
        )
        pseudo_gt_region_score = cv2.applyColorMap(
            (pseudo_gt_region_score * 255).astype("uint8"), cv2.COLORMAP_JET
        )

        overlay_img = cv2.addWeighted(
            word_image[:, :, ::-1], 0.7, pseudo_gt_region_score, 0.3, 5
        )
        vis_result = np.hstack(
            [
                word_image[:, :, ::-1],
                region_score_color,
                word_img_cp1[:, :, ::-1],
                word_img_cp2[:, :, ::-1],
                pseudo_gt_region_score,
                overlay_img,
            ]
        )

        if not os.path.exists(os.path.dirname(self.vis_test_dir)):
            os.makedirs(os.path.dirname(self.vis_test_dir))
        cv2.imwrite(
            os.path.join(
                self.vis_test_dir,
                "{}_{}".format(
                    img_name, f"pseudo_char_bbox_{random.randint(0,100)}.jpg"
                ),
            ),
            vis_result,
        )

    def clip_into_boundary(self, box, bound):
        """Clamp box coordinates into image bounds; *bound* is (h, w, ...)."""
        if len(box) == 0:
            return box
        box[:, :, 0] = np.clip(box[:, :, 0], 0, bound[1])  # x into [0, w]
        box[:, :, 1] = np.clip(box[:, :, 1], 0, bound[0])  # y into [0, h]
        return box

    def get_confidence(self, real_len, pseudo_len):
        """Confidence in [0, 1]: how close the watershed character count
        *pseudo_len* is to the true character count *real_len*."""
        if pseudo_len == 0:
            return 0.0
        return (real_len - min(real_len, abs(real_len - pseudo_len))) / real_len

    def split_word_equal_gap(self, word_img_w, word_img_h, word):
        """Fallback: split the word crop into equal-width character boxes.

        Space characters get no box. Returns an (N, 4, 2) float32 array.
        """
        width = word_img_w
        height = word_img_h

        width_per_char = width / len(word)
        bboxes = []
        for j, char in enumerate(word):
            if char == " ":
                continue
            left = j * width_per_char
            right = (j + 1) * width_per_char
            bbox = np.array([[left, 0], [right, 0], [right, height], [left, height]])
            bboxes.append(bbox)

        return np.array(bboxes, np.float32)

    def cal_angle(self, v1):
        """Angle of vector *v1* in [0, 2*pi), measured clockwise from +x."""
        theta = np.arccos(min(1, v1[0] / (np.linalg.norm(v1) + 10e-8)))
        return 2 * math.pi - theta if v1[1] < 0 else theta

    def clockwise_sort(self, points):
        """Sort four 2D points clockwise around their centroid.

        *points* is a 4x2 array-like; returns a 4x2 ndarray.
        """
        v1, v2, v3, v4 = points
        center = (v1 + v2 + v3 + v4) / 4
        theta = np.array(
            [
                self.cal_angle(v1 - center),
                self.cal_angle(v2 - center),
                self.cal_angle(v3 - center),
                self.cal_angle(v4 - center),
            ]
        )
        index = np.argsort(theta)
        return np.array([v1, v2, v3, v4])[index, :]

    def build_char_box(self, net, gpu, image, word_bbox, word, img_name=""):
        """Produce pseudo character boxes for one word.

        Returns (pseudo_char_bbox, confidence, horizontal_text_bool) with
        boxes mapped back into original-image coordinates.
        """
        word_image, M, horizontal_text_bool = self.crop_image_by_bbox(
            image, word_bbox, word
        )
        # BUG FIX: the original used word.replace("\\s", ""), which removes
        # the literal two-character sequence backslash+s, not whitespace.
        # Spaces must be dropped so the expected character count matches
        # the watershed output.
        real_word_without_space = word.replace(" ", "")
        real_char_len = len(real_word_without_space)

        # normalize crop height to 128 px before inference
        scale = 128.0 / word_image.shape[0]

        word_image = cv2.resize(word_image, None, fx=scale, fy=scale)
        word_img_h, word_img_w, _ = word_image.shape

        scores = self.inference_word_box(net, gpu, word_image)
        region_score = scores[0, :, :, 0].cpu().data.numpy()
        region_score = np.uint8(np.clip(region_score, 0, 1) * 255)

        region_score_rgb = cv2.resize(region_score, (word_img_w, word_img_h))
        region_score_rgb = cv2.cvtColor(region_score_rgb, cv2.COLOR_GRAY2RGB)

        pseudo_char_bbox = exec_watershed_by_version(
            self.watershed_param, region_score, word_image, self.pseudo_vis_opt
        )

        # Used for visualize only
        watershed_box = pseudo_char_bbox.copy()

        pseudo_char_bbox = self.clip_into_boundary(
            pseudo_char_bbox, region_score_rgb.shape
        )

        confidence = self.get_confidence(real_char_len, len(pseudo_char_bbox))

        # unreliable split: fall back to equal-width boxes at confidence 0.5
        if confidence <= 0.5:
            pseudo_char_bbox = self.split_word_equal_gap(word_img_w, word_img_h, word)
            confidence = 0.5

        if self.pseudo_vis_opt and self.flag:
            self.visualize_pseudo_label(
                word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
            )

        if len(pseudo_char_bbox) != 0:
            # order boxes left-to-right by their top-left x coordinate
            index = np.argsort(pseudo_char_bbox[:, 0, 0])
            pseudo_char_bbox = pseudo_char_bbox[index]

        # undo the 128-px scaling, then undo the perspective rectification
        pseudo_char_bbox /= scale

        M_inv = np.linalg.pinv(M)
        for i in range(len(pseudo_char_bbox)):
            pseudo_char_bbox[i] = cv2.perspectiveTransform(
                pseudo_char_bbox[i][None, :, :], M_inv
            )

        pseudo_char_bbox = self.clip_into_boundary(pseudo_char_bbox, image.shape)

        return pseudo_char_bbox, confidence, horizontal_text_bool
|
trainer/craft/data/pseudo_label/watershed.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from skimage.segmentation import watershed
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def segment_region_score(watershed_param, region_score, word_image, pseudo_vis_opt):
    """Split a CRAFT region score map into per-character boxes via watershed.

    Returns an (N, 4, 2) float32 array of axis-aligned character boxes in
    word-image coordinates (the score map is half resolution, hence the
    final doubling).
    """
    score = np.float32(region_score) / 255
    sure_fg = np.uint8(score > 0.75)  # confident character cores
    sure_bg = np.uint8(score < 0.05)  # confident background
    uncertain = 1 - (sure_fg + sure_bg)

    n_labels, markers = cv2.connectedComponents(sure_fg)
    markers += 1                 # reserve label 1 for background
    markers[uncertain == 1] = 0  # let watershed decide the uncertain band

    labels = watershed(-score, markers)

    boxes = []
    for lab in range(2, n_labels + 1):
        ys, xs = np.where(labels == lab)
        rect = np.array(
            [
                [xs.min(), ys.min()],
                [xs.max(), ys.min()],
                [xs.max(), ys.max()],
                [xs.min(), ys.max()],
            ]
        )
        # score map is half the word-image resolution
        boxes.append(rect * 2)
    return np.array(boxes, dtype=np.float32)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def exec_watershed_by_version(
    watershed_param, region_score, word_image, pseudo_vis_opt
):
    """Dispatch to the watershed implementation named by ``watershed_param.version``.

    Returns the selected implementation's result, or None (after printing a
    message) when the version is unknown.
    """
    func_name_map_dict = {
        "skimage": segment_region_score,
    }

    # BUG FIX: the original bare `except:` also swallowed any error raised
    # *inside* the segmentation function and printed a misleading
    # "version does not exist" message. Only the lookup is guarded now;
    # real segmentation failures propagate to the caller.
    try:
        segment_func = func_name_map_dict[watershed_param.version]
    except KeyError:
        print(
            f"Watershed version {watershed_param.version} does not exist in func_name_map_dict."
        )
        return None
    return segment_func(watershed_param, region_score, word_image, pseudo_vis_opt)
|
trainer/craft/data_root_dir/folder.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
place dataset folder here
|
trainer/craft/eval.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.backends.cudnn as cudnn
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import wandb
|
| 12 |
+
|
| 13 |
+
from config.load_config import load_yaml, DotDict
|
| 14 |
+
from model.craft import CRAFT
|
| 15 |
+
from metrics.eval_det_iou import DetectionIoUEvaluator
|
| 16 |
+
from utils.inference_boxes import (
|
| 17 |
+
test_net,
|
| 18 |
+
load_icdar2015_gt,
|
| 19 |
+
load_icdar2013_gt,
|
| 20 |
+
load_synthtext_gt,
|
| 21 |
+
)
|
| 22 |
+
from utils.util import copyStateDict
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Visualize SynthText predictions against ground truth.

    Writes two files into *result_dir*: ``res_<name>.jpg`` with predicted
    (green) and GT (red) polygons drawn on the image, and
    ``res_<name>_box.jpg`` with the score-map overlay montage.
    """
    img = np.array(img)
    img_copy = img.copy()
    region, affinity = pre_output[0], pre_output[1]

    filename, file_ext = os.path.splitext(os.path.basename(img_file))

    # predictions in green; skip any polygon cv2 cannot rasterize
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)
        try:
            cv2.polylines(
                img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    # ground truth in red
    if gt_box is not None:
        for entry in gt_box:
            cv2.polylines(
                img,
                [np.array(entry["points"]).astype(np.int32).reshape((-1, 1, 2))],
                True,
                color=(0, 0, 255),
                thickness=2,
            )

    # score-map overlay montage
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    res_img_path = result_dir + "/res_" + filename + ".jpg"
    cv2.imwrite(res_img_path, img)

    overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
    cv2.imwrite(overlay_image_path, overlay_img)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir):
    """Visualize ICDAR2015 predictions against ground truth.

    Predictions are drawn green; GT boxes red, except ignored ("###")
    entries which are drawn gray. Writes ``res_<name>.jpg`` and the
    ``res_<name>_box.jpg`` overlay montage into *result_dir*.
    """
    img = np.array(img)
    img_copy = img.copy()
    region, affinity = pre_output[0], pre_output[1]

    filename, file_ext = os.path.splitext(os.path.basename(img_file))

    # predictions in green; skip any polygon cv2 cannot rasterize
    for box in pre_box:
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)
        try:
            cv2.polylines(
                img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    if gt_box is not None:
        for entry in gt_box:
            _gt_box = np.array(entry["points"]).reshape(-1, 2).astype(np.int32)
            # "###" marks don't-care regions in the IC15 ground truth
            gt_color = (128, 128, 128) if entry["text"] == "###" else (0, 0, 255)
            cv2.polylines(img, [_gt_box], True, color=gt_color, thickness=2)

    # score-map overlay montage
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    res_img_path = result_dir + "/res_" + filename + ".jpg"
    cv2.imwrite(res_img_path, img)

    overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
    cv2.imwrite(overlay_image_path, overlay_img)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
    """Visualize ICDAR2013 predictions against ground truth.

    Writes ``res_<name>.jpg`` with predicted (green) and GT (red) polygons
    and ``res_<name>_box.jpg`` with the score-map overlay montage into
    *result_dir*.
    """
    img = np.array(img)
    img_copy = img.copy()
    region = pre_output[0]
    affinity = pre_output[1]

    # make result file list
    filename, file_ext = os.path.splitext(os.path.basename(img_file))

    # draw bounding boxes for prediction, color green
    for i, box in enumerate(pre_box):
        poly = np.array(box).astype(np.int32).reshape((-1))
        poly = poly.reshape(-1, 2)
        try:
            cv2.polylines(
                img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
            )
        except:
            pass

    # draw bounding boxes for gt, color red
    if gt_box is not None:
        for j in range(len(gt_box)):
            # BUG FIX: cast GT points to int32 before drawing —
            # cv2.polylines rejects float coordinates, and the synth/2015
            # writers already cast.
            cv2.polylines(
                img,
                [np.array(gt_box[j]["points"]).astype(np.int32).reshape((-1, 1, 2))],
                True,
                color=(0, 0, 255),
                thickness=2,
            )

    # draw overlay image
    overlay_img = overlay(img_copy, region, affinity, pre_box)

    # Save result image
    res_img_path = result_dir + "/res_" + filename + ".jpg"
    cv2.imwrite(res_img_path, img)

    overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
    cv2.imwrite(overlay_image_path, overlay_img)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def overlay(image, region, affinity, single_img_bbox):
    """Build a 2x2 montage: [image | boxed image] over [region | affinity].

    *region* and *affinity* are score visualizations resized to the image;
    each is alpha-blended onto a copy of the image. Word boxes are drawn
    green on a separate copy.
    """
    height, width, channel = image.shape

    region_score = cv2.resize(region, (width, height))
    affinity_score = cv2.resize(affinity, (width, height))

    overlay_region = cv2.addWeighted(image.copy(), 0.4, region_score, 0.6, 5)
    overlay_aff = cv2.addWeighted(image.copy(), 0.4, affinity_score, 0.6, 5)

    boxed_img = image.copy()
    for word_box in single_img_bbox:
        cv2.polylines(
            boxed_img,
            [word_box.astype(np.int32).reshape((-1, 1, 2))],
            True,
            color=(0, 255, 0),
            thickness=3,
        )

    top_row = np.hstack([image, boxed_img])
    bottom_row = np.hstack([overlay_region, overlay_aff])
    return np.vstack([top_row, bottom_row])
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def load_test_dataset_iou(test_folder_name, config):
    """Load ground-truth boxes and image paths for the named test set.

    ``custom_data`` shares the ICDAR2015 annotation format and loader.
    Returns (total_bboxes_gt, total_img_path), or (None, None) for an
    unknown dataset name.
    """
    if test_folder_name == "synthtext":
        return load_synthtext_gt(config.test_data_dir)
    if test_folder_name == "icdar2013":
        return load_icdar2013_gt(dataFolder=config.test_data_dir)
    if test_folder_name in ("icdar2015", "custom_data"):
        return load_icdar2015_gt(dataFolder=config.test_data_dir)
    print("not found test dataset")
    return None, None
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name):
    """Dispatch result visualization to the dataset-specific writer.

    The image is flipped RGB->BGR (``[:, :, ::-1]``) for the cv2-based
    writers. ``custom_data`` reuses the ICDAR2015 writer.
    """
    savers = {
        "synthtext": save_result_synth,
        "icdar2013": save_result_2013,
        "icdar2015": save_result_2015,
        "custom_data": save_result_2015,
    }
    saver = savers.get(test_folder_name)
    if saver is None:
        print("not found test dataset")
        return
    saver(img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode):
    """Run IoU-based detection evaluation over the custom test set.

    Two modes:
      * ``model is None`` — standalone evaluation: load weights from
        *model_path* into a fresh CRAFT net and evaluate every image.
      * ``model`` given — distributed evaluation during training: this
        process handles only its GPU's slice of the images and publishes
        per-image results into the shared *buffer* list.

    Returns the combined metrics dict from ``evaluator.combine_results``.
    """

    if not os.path.exists(result_dir):
        os.makedirs(result_dir, exist_ok=True)

    total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config)

    # In weak-supervision mode half the GPUs are reserved for training,
    # so only the other half participate in evaluation.
    if mode == "weak_supervision" and torch.cuda.device_count() != 1:
        gpu_count = torch.cuda.device_count() // 2
    else:
        gpu_count = torch.cuda.device_count()
    gpu_idx = torch.cuda.current_device()
    torch.cuda.set_device(gpu_idx)

    # Only evaluation time
    if model is None:
        piece_imgs_path = total_imgs_path

        if backbone == "vgg":
            model = CRAFT()
        else:
            raise Exception("Undefined architecture")

        print("Loading weights from checkpoint (" + model_path + ")")
        net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
        model.load_state_dict(copyStateDict(net_param["craft"]))

        if config.cuda:
            model = model.cuda()
            cudnn.benchmark = False

    # Distributed evaluation in the middle of training time
    else:
        if buffer is not None:
            # check all buffer value is None for distributed evaluation
            assert all(
                v is None for v in buffer
            ), "Buffer already filled with another value."
        slice_idx = len(total_imgs_bboxes_gt) // gpu_count

        # last gpu takes the remainder of the integer division
        if gpu_idx == gpu_count - 1:
            piece_imgs_path = total_imgs_path[gpu_idx * slice_idx :]
            # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx:]
        else:
            piece_imgs_path = total_imgs_path[
                gpu_idx * slice_idx : (gpu_idx + 1) * slice_idx
            ]
            # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx: (gpu_idx + 1) * slice_idx]

    model.eval()

    # NOTE(review): slice_idx is only defined on the distributed branch;
    # calling with model=None and buffer not None would raise NameError
    # below — presumably that combination never occurs. TODO confirm.
    # -----------------------------------------------------------------------------------------------------------------#
    total_imgs_bboxes_pre = []
    for k, img_path in enumerate(tqdm(piece_imgs_path)):
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        single_img_bbox = []
        bboxes, polys, score_text = test_net(
            model,
            image,
            config.text_threshold,
            config.link_threshold,
            config.low_text,
            config.cuda,
            config.poly,
            config.canvas_size,
            config.mag_ratio,
        )

        # Every predicted box is wrapped in the evaluator's expected format;
        # "###" / ignore flags are placeholders (predictions have no text).
        for box in bboxes:
            box_info = {"points": box, "text": "###", "ignore": False}
            single_img_bbox.append(box_info)
        total_imgs_bboxes_pre.append(single_img_bbox)
        # Distributed evaluation -------------------------------------------------------------------------------------#
        if buffer is not None:
            buffer[gpu_idx * slice_idx + k] = single_img_bbox
        # print(sum([element is not None for element in buffer]))
        # -------------------------------------------------------------------------------------------------------------#

        if config.vis_opt:
            viz_test(
                image,
                score_text,
                pre_box=polys,
                gt_box=total_imgs_bboxes_gt[k],
                img_name=img_path,
                result_dir=result_dir,
                test_folder_name="custom_data",
            )

    # When distributed evaluation mode, wait until buffer is full filled
    # NOTE(review): this is a busy-wait spin on a shared list — it assumes
    # the other processes fill their slots; verify there is no deadlock path.
    if buffer is not None:
        while None in buffer:
            continue
        assert all(v is not None for v in buffer), "Buffer not filled"
        total_imgs_bboxes_pre = buffer

    results = []
    for i, (gt, pred) in enumerate(zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)):
        perSampleMetrics_dict = evaluator.evaluate_image(gt, pred)
        results.append(perSampleMetrics_dict)

    metrics = evaluator.combine_results(results)
    print(metrics)
    return metrics
|
| 333 |
+
|
| 334 |
+
def cal_eval(config, data, res_dir_name, opt, mode):
    """Set up an IoU evaluator and run ``main_eval`` for dataset *data*.

    *opt* selects the evaluation type; only "iou_eval" is implemented.
    Results are written under ``exp/<yaml>/<res_dir_name>``.

    NOTE(review): this reads the module-global ``args`` (set in the
    __main__ block), so it only works when eval.py runs as a script —
    confirm before importing it from elsewhere.
    """
    evaluator = DetectionIoUEvaluator()
    test_config = DotDict(config.test[data])
    res_dir = os.path.join(os.path.join("exp", args.yaml), "{}".format(res_dir_name))

    if opt == "iou_eval":
        main_eval(
            config.test.trained_model,
            config.train.backbone,
            test_config,
            evaluator,
            res_dir,
            buffer=None,  # no distributed buffer: standalone evaluation
            model=None,   # force main_eval to load weights from disk
            mode=mode,
        )
    else:
        print("Undefined evaluation")
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if __name__ == "__main__":

    # Script entry point: evaluate a trained CRAFT checkpoint on the
    # custom_data test set with IoU metrics.
    parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    args = parser.parse_args()

    # load configure
    config = load_yaml(args.yaml)
    config = DotDict(config)

    # Optional experiment tracking; entity/project names are hard-coded.
    if config["wandb_opt"]:
        wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
        wandb.config.update(config)

    # Results land in exp/<yaml>/<yaml>-ic15-iou
    val_result_dir_name = args.yaml
    cal_eval(
        config,
        "custom_data",
        val_result_dir_name + "-ic15-iou",
        opt="iou_eval",
        mode=None,
    )
|
trainer/craft/exp/folder.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
trained model will be saved here
|
trainer/craft/loss/mseloss.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Loss(nn.Module):
    """Confidence-weighted mean squared error over region and affinity maps."""

    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
        """Return the mean of per-pixel squared errors scaled by ``conf_map``."""
        region_err = (gt_region - pred_region) ** 2
        affinity_err = (gt_affinity - pred_affinity) ** 2
        weighted = conf_map * (region_err + affinity_err)
        return weighted.mean()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Maploss_v2(nn.Module):
    """Batch-level online hard-negative-mined MSE loss (variant 2).

    Pixels with label score > 0.1 are positives; negatives are hard-mined
    over the whole batch with at most ``neg_rto`` negatives per positive and
    a floor of ``n_min_neg`` mined negatives.
    """

    def __init__(self):

        super(Maploss_v2, self).__init__()

    def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
        """Combine mean positive loss with hard-mined negative loss.

        Args:
            pred_score: Per-pixel (already masked) squared-error map.
            label_score: Ground-truth score map used to split pos/neg pixels.
            neg_rto: Maximum ratio of mined negatives to positives.
            n_min_neg: Minimum number of negatives to mine.
        """
        # positive pixels and their mean loss
        positive_pixel = (label_score > 0.1).float()
        positive_pixel_number = torch.sum(positive_pixel)

        positive_loss_region = pred_score * positive_pixel

        # negative pixels
        negative_pixel = (label_score <= 0.1).float()
        negative_pixel_number = torch.sum(negative_pixel)
        negative_loss_region = pred_score * negative_pixel

        if positive_pixel_number != 0:
            if negative_pixel_number < neg_rto * positive_pixel_number:
                # Too few negatives for the ratio: mine the floor count.
                # Clamp k so topk cannot request more elements than exist.
                k = min(n_min_neg, negative_loss_region.numel())
                negative_loss = (
                    torch.sum(
                        torch.topk(
                            negative_loss_region.view(-1), k, sorted=False
                        )[0]
                    )
                    / k
                )
            else:
                # Mine the hardest neg_rto * #positives negatives.
                negative_loss = torch.sum(
                    torch.topk(
                        negative_loss_region.view(-1),
                        int(neg_rto * positive_pixel_number),
                        sorted=False,
                    )[0]
                ) / (positive_pixel_number * neg_rto)
            positive_loss = torch.sum(positive_loss_region) / positive_pixel_number
        else:
            # only negative pixels in this batch
            k = min(n_min_neg, negative_loss_region.numel())
            negative_loss = (
                torch.sum(
                    torch.topk(negative_loss_region.view(-1), k, sorted=False)[0]
                )
                / k
            )
            positive_loss = 0.0
        total_loss = positive_loss + negative_loss
        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_socres_label,  # name typo kept: callers may pass by keyword
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return hard-mined region loss + affinity loss for the batch."""
        # reduction="none" keeps per-pixel errors; replaces the long-deprecated
        # reduce=False / size_average=False flags.
        loss_fn = torch.nn.MSELoss(reduction="none")
        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_socres_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_socres_label)

        # Confidence mask weights pseudo-labels by their reliability.
        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)

        char_loss = self.batch_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.batch_image_loss(
            loss_affinity, affinity_socres_label, neg_rto, n_min_neg
        )
        return char_loss + affi_loss
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Maploss_v3(nn.Module):
    """Per-image online hard-negative-mined MSE loss (variant 3).

    Unlike Maploss_v2, positives/negatives are split and mined per image and
    the result is averaged over the batch.
    """

    def __init__(self):

        super(Maploss_v3, self).__init__()

    def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
        """Average, over the batch, of per-image positive + hard-negative loss.

        Args:
            pre_loss: Per-pixel (already masked) squared-error map, batch-first.
            loss_label: Ground-truth score maps used to split pos/neg pixels.
            neg_rto: Maximum ratio of mined negatives to positives.
            n_min_neg: Minimum number of negatives to mine per image.
        """
        batch_size = pre_loss.shape[0]

        positive_loss, negative_loss = 0, 0
        for single_loss, single_label in zip(pre_loss, loss_label):

            # positive pixels: label >= 0.1 (epsilon guards a 0-positive image)
            pos_pixel = (single_label >= 0.1).float()
            n_pos_pixel = torch.sum(pos_pixel)
            pos_loss_region = single_loss * pos_pixel
            positive_loss += torch.sum(pos_loss_region) / max(n_pos_pixel, 1e-12)

            # negative pixels
            neg_pixel = (single_label < 0.1).float()
            n_neg_pixel = torch.sum(neg_pixel)
            neg_loss_region = single_loss * neg_pixel

            if n_pos_pixel != 0:
                if n_neg_pixel < neg_rto * n_pos_pixel:
                    # Guard the all-positive image: 0/0 previously produced NaN.
                    if n_neg_pixel > 0:
                        negative_loss += torch.sum(neg_loss_region) / n_neg_pixel
                else:
                    # Mine the hardest max(n_min_neg, neg_rto * #pos) negatives.
                    n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel)
                    negative_loss += (
                        torch.sum(
                            torch.topk(neg_loss_region.view(-1), int(n_hard_neg))[0]
                        )
                        / n_hard_neg
                    )
            else:
                # only negative pixels in this image
                negative_loss += (
                    torch.sum(torch.topk(neg_loss_region.view(-1), n_min_neg)[0])
                    / n_min_neg
                )

        total_loss = (positive_loss + negative_loss) / batch_size

        return total_loss

    def forward(
        self,
        region_scores_label,
        affinity_scores_label,
        region_scores_pre,
        affinity_scores_pre,
        mask,
        neg_rto,
        n_min_neg,
    ):
        """Return per-image hard-mined region loss + affinity loss."""
        # reduction="none" keeps per-pixel errors; replaces the long-deprecated
        # reduce=False / size_average=False flags.
        loss_fn = torch.nn.MSELoss(reduction="none")

        assert (
            region_scores_label.size() == region_scores_pre.size()
            and affinity_scores_label.size() == affinity_scores_pre.size()
        )
        loss1 = loss_fn(region_scores_pre, region_scores_label)
        loss2 = loss_fn(affinity_scores_pre, affinity_scores_label)

        # Confidence mask weights pseudo-labels by their reliability.
        loss_region = torch.mul(loss1, mask)
        loss_affinity = torch.mul(loss2, mask)
        char_loss = self.single_image_loss(
            loss_region, region_scores_label, neg_rto, n_min_neg
        )
        affi_loss = self.single_image_loss(
            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
        )

        return char_loss + affi_loss
|
trainer/craft/metrics/eval_det_iou.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
from shapely.geometry import Polygon
|
| 6 |
+
"""
|
| 7 |
+
cite from:
|
| 8 |
+
PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR
|
| 9 |
+
PaddleOCR reference from :
|
| 10 |
+
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DetectionIoUEvaluator(object):
    """IoU-based text detection evaluator (ICDAR-2015 style protocol).

    cite from:
    PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR
    PaddleOCR reference from :
    https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
    """

    def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
        # Minimum IoU for a GT/detection pair to count as a match.
        self.iou_constraint = iou_constraint
        # Overlap ratio above which a detection is absorbed by a don't-care
        # GT polygon and excluded from scoring.
        self.area_precision_constraint = area_precision_constraint

    def evaluate_image(self, gt, pred):
        """Score a single image.

        Args:
            gt: List of dicts with keys 'points' (polygon vertices) and
                'ignore' (bool don't-care flag).
            pred: List of dicts with key 'points'.

        Returns:
            Dict of per-sample metrics ('precision', 'recall', 'hmean',
            matched 'pairs', care counts, and an 'evaluationLog' string).
        """
        def get_union(pD, pG):
            return Polygon(pD).union(Polygon(pG)).area

        def get_intersection(pD, pG):
            return Polygon(pD).intersection(Polygon(pG)).area

        def get_intersection_over_union(pD, pG):
            return get_intersection(pD, pG) / get_union(pD, pG)

        recall = 0
        precision = 0
        hmean = 0

        detMatched = 0

        iouMat = np.empty([1, 1])

        gtPols = []
        detPols = []

        gtPolPoints = []
        detPolPoints = []

        # Indices of Ground Truth polygons marked as don't care
        gtDontCarePolsNum = []
        # Indices of detected polygons matched with a don't care GT
        detDontCarePolsNum = []

        pairs = []
        detMatchedNums = []

        evaluationLog = ""

        for n in range(len(gt)):
            points = gt[n]['points']
            dontCare = gt[n]['ignore']
            # Skip degenerate or self-intersecting polygons. A malformed
            # points list previously dropped into an ipdb debugger session;
            # it is now skipped like any other invalid polygon.
            try:
                if not Polygon(points).is_valid or not Polygon(points).is_simple:
                    continue
            except Exception:
                continue

            gtPol = points
            gtPols.append(gtPol)
            gtPolPoints.append(points)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
            if len(gtDontCarePolsNum) > 0 else "\n")

        for n in range(len(pred)):
            points = pred[n]['points']
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue

            detPol = points
            detPols.append(detPol)
            detPolPoints.append(points)
            # A detection that mostly overlaps a don't-care GT region is
            # itself marked don't-care and excluded from scoring.
            if len(gtDontCarePolsNum) > 0:
                for dontCarePol in gtDontCarePolsNum:
                    dontCarePol = gtPols[dontCarePol]
                    intersected_area = get_intersection(dontCarePol, detPol)
                    pdDimensions = Polygon(detPol).area
                    precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                    if (precision > self.area_precision_constraint):
                        detDontCarePolsNum.append(len(detPols) - 1)
                        break

        evaluationLog += "DET polygons: " + str(len(detPols)) + (
            " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
            if len(detDontCarePolsNum) > 0 else "\n")

        if len(gtPols) > 0 and len(detPols) > 0:
            # Calculate the full GT x DET IoU matrix.
            outputShape = [len(gtPols), len(detPols)]
            iouMat = np.empty(outputShape)
            gtRectMat = np.zeros(len(gtPols), np.int8)
            detRectMat = np.zeros(len(detPols), np.int8)
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    pG = gtPols[gtNum]
                    pD = detPols[detNum]
                    iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)

            # Greedy one-to-one matching above the IoU threshold.
            for gtNum in range(len(gtPols)):
                for detNum in range(len(detPols)):
                    if gtRectMat[gtNum] == 0 and detRectMat[
                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                        if iouMat[gtNum, detNum] > self.iou_constraint:
                            gtRectMat[gtNum] = 1
                            detRectMat[detNum] = 1
                            detMatched += 1
                            pairs.append({'gt': gtNum, 'det': detNum})
                            detMatchedNums.append(detNum)
                            evaluationLog += "Match GT #" + \
                                str(gtNum) + " with Det #" + str(detNum) + "\n"

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            # No GT to find: recall is trivially 1; precision is 1 only if
            # there were also no (cared-about) detections.
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare

        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        perSampleMetrics = {
            'precision': precision,
            'recall': recall,
            'hmean': hmean,
            'pairs': pairs,
            # The IoU matrix is dropped for very dense images to keep results small.
            'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
            'gtPolPoints': gtPolPoints,
            'detPolPoints': detPolPoints,
            'gtCare': numGtCare,
            'detCare': numDetCare,
            'gtDontCare': gtDontCarePolsNum,
            'detDontCare': detDontCarePolsNum,
            'detMatched': detMatched,
            'evaluationLog': evaluationLog
        }

        return perSampleMetrics

    def combine_results(self, results):
        """Aggregate per-image metrics into dataset precision/recall/hmean."""
        numGlobalCareGt = 0
        numGlobalCareDet = 0
        matchedSum = 0
        for result in results:
            numGlobalCareGt += result['gtCare']
            numGlobalCareDet += result['detCare']
            matchedSum += result['detMatched']

        methodRecall = 0 if numGlobalCareGt == 0 else float(
            matchedSum) / numGlobalCareGt
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (
                methodRecall + methodPrecision)
        methodMetrics = {
            'precision': methodPrecision,
            'recall': methodRecall,
            'hmean': methodHmean
        }

        return methodMetrics
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
if __name__ == '__main__':
    # Tiny smoke test: two GT boxes, one detection overlapping the first.
    evaluator = DetectionIoUEvaluator()
    gts = [[
        {'points': [(0, 0), (1, 0), (1, 1), (0, 1)], 'text': 1234, 'ignore': False},
        {'points': [(2, 2), (3, 2), (3, 3), (2, 3)], 'text': 5678, 'ignore': False},
    ]]
    preds = [[
        {'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], 'text': 123, 'ignore': False},
    ]]
    results = [
        evaluator.evaluate_image(gt, pred) for gt, pred in zip(gts, preds)
    ]
    metrics = evaluator.combine_results(results)
    print(metrics)
|
trainer/craft/model/craft.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2019-present NAVER Corp.
|
| 3 |
+
MIT License
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# -*- coding: utf-8 -*-
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from model.vgg16_bn import vgg16_bn, init_weights
|
| 12 |
+
|
| 13 |
+
class double_conv(nn.Module):
    """Two conv+BN+ReLU stages used by the CRAFT U-network decoder.

    The first 1x1 conv fuses an input with ``in_ch + mid_ch`` channels down
    to ``mid_ch``; the 3x3 conv then produces ``out_ch`` channels.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        layers = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages; spatial size is preserved."""
        return self.conv(x)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CRAFT(nn.Module):
    """CRAFT text detector: VGG16-BN backbone + U-network + score head.

    Args:
        pretrained: Load ImageNet weights into the VGG backbone.
        freeze: Freeze the backbone's first slice (passed to vgg16_bn).
        amp: Run the forward pass under torch.cuda.amp.autocast.
    """

    def __init__(self, pretrained=True, freeze=False, amp=False):
        super(CRAFT, self).__init__()

        self.amp = amp

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        # Two output channels: region score and affinity score.
        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def _forward_impl(self, x):
        """Shared forward pass (previously duplicated in amp/non-amp branches)."""
        sources = self.basenet(x)

        """ U network: progressively fuse deep features with shallower ones """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        # Scores are returned channels-last: (N, H, W, 2).
        return y.permute(0, 2, 3, 1), feature

    def forward(self, x):
        """Return (score maps, last decoder feature map)."""
        if self.amp:
            with torch.cuda.amp.autocast():
                return self._forward_impl(x)
        return self._forward_impl(x)
|
| 108 |
+
|
| 109 |
+
if __name__ == '__main__':
    # Smoke test: push one dummy 768x768 image through the model on GPU and
    # print the resulting score-map shape. Requires a CUDA device and will
    # download pretrained VGG weights on first run.
    model = CRAFT(pretrained=True).cuda()
    output, _ = model(torch.randn(1, 3, 768, 768).cuda())
    print(output.shape)
|
trainer/craft/model/vgg16_bn.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.init as init
|
| 4 |
+
import torchvision
|
| 5 |
+
from torchvision import models
|
| 6 |
+
from packaging import version
|
| 7 |
+
|
| 8 |
+
def init_weights(modules):
    """Initialize layers in-place by type.

    Conv2d: Xavier-uniform weights, zero bias (when present).
    BatchNorm2d: unit weight, zero bias.
    Linear: N(0, 0.01) weights, zero bias.
    Other module types are left untouched.
    """
    for layer in modules:
        if isinstance(layer, nn.Conv2d):
            init.xavier_uniform_(layer.weight.data)
            bias = layer.bias
            if bias is not None:
                bias.data.zero_()
        elif isinstance(layer, nn.BatchNorm2d):
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()
        elif isinstance(layer, nn.Linear):
            layer.weight.data.normal_(0, 0.01)
            layer.bias.data.zero_()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class vgg16_bn(torch.nn.Module):
    """VGG16-BN feature extractor returning five intermediate feature maps.

    ``forward`` returns (fc7, relu5_3, relu4_3, relu3_2, relu2_2) — deepest
    first — for consumption by a U-style decoder.
    """

    def __init__(self, pretrained=True, freeze=True):
        """
        Args:
            pretrained: Load ImageNet weights for the VGG trunk (downloads on
                first use); otherwise all slices are re-initialized.
            freeze: Disable gradients for everything in slice1.
        """
        super(vgg16_bn, self).__init__()
        # torchvision >= 0.13 replaced the ``pretrained`` flag with ``weights``.
        if version.parse(torchvision.__version__) >= version.parse('0.13'):
            vgg_pretrained_features = models.vgg16_bn(
                weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
            ).features
        else: # torchvision.__version__ < 0.13
            # Workaround for environments where the https download fails.
            models.vgg.model_urls['vgg16_bn'] = models.vgg.model_urls['vgg16_bn'].replace('https://', 'http://')
            vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features

        # Split the VGG feature stack into slices; the trailing comments name
        # the last conv layer each slice ends at.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):  # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):  # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):  # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):  # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            # NOTE(review): this freezes every layer in slice1 (through
            # conv2_2), not just the first conv as the original comment said.
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad= False

    def forward(self, X):
        """Return the five tap-points of the VGG trunk, deepest first."""
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2
|
trainer/craft/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
conda==4.10.3
|
| 2 |
+
opencv-python==4.5.3.56
|
| 3 |
+
Pillow==9.3.0
|
| 4 |
+
Polygon3==3.0.9.1
|
| 5 |
+
PyYAML==5.4.1
|
| 6 |
+
scikit-image==0.17.2
|
| 7 |
+
Shapely==1.8.0
|
| 8 |
+
torch==1.13.1
|
| 9 |
+
torchvision==0.14.1
|
| 10 |
+
wandb==0.12.9
|
trainer/craft/scripts/run_cde.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch a single-GPU CRAFT training run under a throwaway experiment name.
# If this file was edited on Windows, strip CR line endings first with:
# sed -i -e 's/\r$//' scripts/run_cde.sh
EXP_NAME=custom_data_release_test_3
# Copy the base config to a per-experiment YAML so train.py picks it up by name.
yaml_path="config/$EXP_NAME.yaml"
cp config/custom_data_train.yaml $yaml_path
# Multi-GPU distributed variant (disabled):
#CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=$EXP_NAME --port=2468
CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=$EXP_NAME --port=2468
# Remove the temporary per-experiment config when training finishes.
rm "config/$EXP_NAME.yaml"
|
trainer/craft/train.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
import multiprocessing as mp
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
import wandb
|
| 14 |
+
|
| 15 |
+
from config.load_config import load_yaml, DotDict
|
| 16 |
+
from data.dataset import SynthTextDataSet, CustomDataset
|
| 17 |
+
from loss.mseloss import Maploss_v2, Maploss_v3
|
| 18 |
+
from model.craft import CRAFT
|
| 19 |
+
from eval import main_eval
|
| 20 |
+
from metrics.eval_det_iou import DetectionIoUEvaluator
|
| 21 |
+
from utils.util import copyStateDict, save_parser
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Trainer(object):
|
| 25 |
+
    def __init__(self, config, gpu, mode):
        """Hold training configuration and pre-load the checkpoint.

        Args:
            config: DotDict experiment configuration.
            gpu: GPU index this trainer runs on (used as torch.load target).
            mode: Training mode flag, stored for later use.
        """
        self.config = config
        self.gpu = gpu
        self.mode = mode
        # Checkpoint payload (or None when config.train.ckpt_path is unset),
        # mapped onto this trainer's GPU.
        self.net_param = self.get_load_param(gpu)
|
| 31 |
+
|
| 32 |
+
    def get_synth_loader(self):
        """Build a DataLoader over the SynthText dataset.

        The batch size is the configured train batch size divided by
        ``synth_ratio``, so synthetic samples fill only a fraction of each
        combined training batch.
        """
        dataset = SynthTextDataSet(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.train.synth_data_dir,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            aug=self.config.train.data.syn_aug,
            vis_test_dir=self.config.vis_test_dir,
            vis_opt=self.config.train.data.vis_opt,
            sample=self.config.train.data.syn_sample,
        )

        syn_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return syn_loader
|
| 59 |
+
|
| 60 |
+
def get_custom_dataset(self):
|
| 61 |
+
|
| 62 |
+
custom_dataset = CustomDataset(
|
| 63 |
+
output_size=self.config.train.data.output_size,
|
| 64 |
+
data_dir=self.config.data_root_dir,
|
| 65 |
+
saved_gt_dir=None,
|
| 66 |
+
mean=self.config.train.data.mean,
|
| 67 |
+
variance=self.config.train.data.variance,
|
| 68 |
+
gauss_init_size=self.config.train.data.gauss_init_size,
|
| 69 |
+
gauss_sigma=self.config.train.data.gauss_sigma,
|
| 70 |
+
enlarge_region=self.config.train.data.enlarge_region,
|
| 71 |
+
enlarge_affinity=self.config.train.data.enlarge_affinity,
|
| 72 |
+
watershed_param=self.config.train.data.watershed,
|
| 73 |
+
aug=self.config.train.data.custom_aug,
|
| 74 |
+
vis_test_dir=self.config.vis_test_dir,
|
| 75 |
+
sample=self.config.train.data.custom_sample,
|
| 76 |
+
vis_opt=self.config.train.data.vis_opt,
|
| 77 |
+
pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
|
| 78 |
+
do_not_care_label=self.config.train.data.do_not_care_label,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
return custom_dataset
|
| 82 |
+
|
| 83 |
+
def get_load_param(self, gpu):
|
| 84 |
+
|
| 85 |
+
if self.config.train.ckpt_path is not None:
|
| 86 |
+
map_location = "cuda:%d" % gpu
|
| 87 |
+
param = torch.load(self.config.train.ckpt_path, map_location=map_location)
|
| 88 |
+
else:
|
| 89 |
+
param = None
|
| 90 |
+
|
| 91 |
+
return param
|
| 92 |
+
|
| 93 |
+
def adjust_learning_rate(self, optimizer, gamma, step, lr):
|
| 94 |
+
lr = lr * (gamma ** step)
|
| 95 |
+
for param_group in optimizer.param_groups:
|
| 96 |
+
param_group["lr"] = lr
|
| 97 |
+
return param_group["lr"]
|
| 98 |
+
|
| 99 |
+
def get_loss(self):
|
| 100 |
+
if self.config.train.loss == 2:
|
| 101 |
+
criterion = Maploss_v2()
|
| 102 |
+
elif self.config.train.loss == 3:
|
| 103 |
+
criterion = Maploss_v3()
|
| 104 |
+
else:
|
| 105 |
+
raise Exception("Undefined loss")
|
| 106 |
+
return criterion
|
| 107 |
+
|
| 108 |
+
def iou_eval(self, dataset, train_step, buffer, model):
|
| 109 |
+
test_config = DotDict(self.config.test[dataset])
|
| 110 |
+
|
| 111 |
+
val_result_dir = os.path.join(
|
| 112 |
+
self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
evaluator = DetectionIoUEvaluator()
|
| 116 |
+
|
| 117 |
+
metrics = main_eval(
|
| 118 |
+
None,
|
| 119 |
+
self.config.train.backbone,
|
| 120 |
+
test_config,
|
| 121 |
+
evaluator,
|
| 122 |
+
val_result_dir,
|
| 123 |
+
buffer,
|
| 124 |
+
model,
|
| 125 |
+
self.mode,
|
| 126 |
+
)
|
| 127 |
+
if self.gpu == 0 and self.config.wandb_opt:
|
| 128 |
+
wandb.log(
|
| 129 |
+
{
|
| 130 |
+
"{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
|
| 131 |
+
"{} iou Precision".format(dataset): np.round(
|
| 132 |
+
metrics["precision"], 3
|
| 133 |
+
),
|
| 134 |
+
"{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
|
| 135 |
+
}
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def train(self, buffer_dict):
|
| 139 |
+
|
| 140 |
+
torch.cuda.set_device(self.gpu)
|
| 141 |
+
|
| 142 |
+
# MODEL -------------------------------------------------------------------------------------------------------#
|
| 143 |
+
# SUPERVISION model
|
| 144 |
+
if self.config.mode == "weak_supervision":
|
| 145 |
+
if self.config.train.backbone == "vgg":
|
| 146 |
+
supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
|
| 147 |
+
else:
|
| 148 |
+
raise Exception("Undefined architecture")
|
| 149 |
+
|
| 150 |
+
supervision_device = self.gpu
|
| 151 |
+
if self.config.train.ckpt_path is not None:
|
| 152 |
+
supervision_param = self.get_load_param(supervision_device)
|
| 153 |
+
supervision_model.load_state_dict(
|
| 154 |
+
copyStateDict(supervision_param["craft"])
|
| 155 |
+
)
|
| 156 |
+
supervision_model = supervision_model.to(f"cuda:{supervision_device}")
|
| 157 |
+
print(f"Supervision model loading on : gpu {supervision_device}")
|
| 158 |
+
else:
|
| 159 |
+
supervision_model, supervision_device = None, None
|
| 160 |
+
|
| 161 |
+
# TRAIN model
|
| 162 |
+
if self.config.train.backbone == "vgg":
|
| 163 |
+
craft = CRAFT(pretrained=False, amp=self.config.train.amp)
|
| 164 |
+
else:
|
| 165 |
+
raise Exception("Undefined architecture")
|
| 166 |
+
|
| 167 |
+
if self.config.train.ckpt_path is not None:
|
| 168 |
+
craft.load_state_dict(copyStateDict(self.net_param["craft"]))
|
| 169 |
+
|
| 170 |
+
craft = craft.cuda()
|
| 171 |
+
craft = torch.nn.DataParallel(craft)
|
| 172 |
+
|
| 173 |
+
torch.backends.cudnn.benchmark = True
|
| 174 |
+
|
| 175 |
+
# DATASET -----------------------------------------------------------------------------------------------------#
|
| 176 |
+
|
| 177 |
+
if self.config.train.use_synthtext:
|
| 178 |
+
trn_syn_loader = self.get_synth_loader()
|
| 179 |
+
batch_syn = iter(trn_syn_loader)
|
| 180 |
+
|
| 181 |
+
if self.config.train.real_dataset == "custom":
|
| 182 |
+
trn_real_dataset = self.get_custom_dataset()
|
| 183 |
+
else:
|
| 184 |
+
raise Exception("Undefined dataset")
|
| 185 |
+
|
| 186 |
+
if self.config.mode == "weak_supervision":
|
| 187 |
+
trn_real_dataset.update_model(supervision_model)
|
| 188 |
+
trn_real_dataset.update_device(supervision_device)
|
| 189 |
+
|
| 190 |
+
trn_real_loader = torch.utils.data.DataLoader(
|
| 191 |
+
trn_real_dataset,
|
| 192 |
+
batch_size=self.config.train.batch_size,
|
| 193 |
+
shuffle=False,
|
| 194 |
+
num_workers=self.config.train.num_workers,
|
| 195 |
+
drop_last=False,
|
| 196 |
+
pin_memory=True,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# OPTIMIZER ---------------------------------------------------------------------------------------------------#
|
| 200 |
+
optimizer = optim.Adam(
|
| 201 |
+
craft.parameters(),
|
| 202 |
+
lr=self.config.train.lr,
|
| 203 |
+
weight_decay=self.config.train.weight_decay,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
|
| 207 |
+
optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
|
| 208 |
+
self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
|
| 209 |
+
self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
|
| 210 |
+
|
| 211 |
+
# LOSS --------------------------------------------------------------------------------------------------------#
|
| 212 |
+
# mixed precision
|
| 213 |
+
if self.config.train.amp:
|
| 214 |
+
scaler = torch.cuda.amp.GradScaler()
|
| 215 |
+
|
| 216 |
+
if (
|
| 217 |
+
self.config.train.ckpt_path is not None
|
| 218 |
+
and self.config.train.st_iter != 0
|
| 219 |
+
):
|
| 220 |
+
scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
|
| 221 |
+
else:
|
| 222 |
+
scaler = None
|
| 223 |
+
|
| 224 |
+
criterion = self.get_loss()
|
| 225 |
+
|
| 226 |
+
# TRAIN -------------------------------------------------------------------------------------------------------#
|
| 227 |
+
train_step = self.config.train.st_iter
|
| 228 |
+
whole_training_step = self.config.train.end_iter
|
| 229 |
+
update_lr_rate_step = 0
|
| 230 |
+
training_lr = self.config.train.lr
|
| 231 |
+
loss_value = 0
|
| 232 |
+
batch_time = 0
|
| 233 |
+
start_time = time.time()
|
| 234 |
+
|
| 235 |
+
print(
|
| 236 |
+
"================================ Train start ================================"
|
| 237 |
+
)
|
| 238 |
+
while train_step < whole_training_step:
|
| 239 |
+
for (
|
| 240 |
+
index,
|
| 241 |
+
(
|
| 242 |
+
images,
|
| 243 |
+
region_scores,
|
| 244 |
+
affinity_scores,
|
| 245 |
+
confidence_masks,
|
| 246 |
+
),
|
| 247 |
+
) in enumerate(trn_real_loader):
|
| 248 |
+
craft.train()
|
| 249 |
+
if train_step > 0 and train_step % self.config.train.lr_decay == 0:
|
| 250 |
+
update_lr_rate_step += 1
|
| 251 |
+
training_lr = self.adjust_learning_rate(
|
| 252 |
+
optimizer,
|
| 253 |
+
self.config.train.gamma,
|
| 254 |
+
update_lr_rate_step,
|
| 255 |
+
self.config.train.lr,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
images = images.cuda(non_blocking=True)
|
| 259 |
+
region_scores = region_scores.cuda(non_blocking=True)
|
| 260 |
+
affinity_scores = affinity_scores.cuda(non_blocking=True)
|
| 261 |
+
confidence_masks = confidence_masks.cuda(non_blocking=True)
|
| 262 |
+
|
| 263 |
+
if self.config.train.use_synthtext:
|
| 264 |
+
# Synth image load
|
| 265 |
+
syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
|
| 266 |
+
batch_syn
|
| 267 |
+
)
|
| 268 |
+
syn_image = syn_image.cuda(non_blocking=True)
|
| 269 |
+
syn_region_label = syn_region_label.cuda(non_blocking=True)
|
| 270 |
+
syn_affi_label = syn_affi_label.cuda(non_blocking=True)
|
| 271 |
+
syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
|
| 272 |
+
|
| 273 |
+
# concat syn & custom image
|
| 274 |
+
images = torch.cat((syn_image, images), 0)
|
| 275 |
+
region_image_label = torch.cat(
|
| 276 |
+
(syn_region_label, region_scores), 0
|
| 277 |
+
)
|
| 278 |
+
affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
|
| 279 |
+
confidence_mask_label = torch.cat(
|
| 280 |
+
(syn_confidence_mask, confidence_masks), 0
|
| 281 |
+
)
|
| 282 |
+
else:
|
| 283 |
+
region_image_label = region_scores
|
| 284 |
+
affinity_image_label = affinity_scores
|
| 285 |
+
confidence_mask_label = confidence_masks
|
| 286 |
+
|
| 287 |
+
if self.config.train.amp:
|
| 288 |
+
with torch.cuda.amp.autocast():
|
| 289 |
+
|
| 290 |
+
output, _ = craft(images)
|
| 291 |
+
out1 = output[:, :, :, 0]
|
| 292 |
+
out2 = output[:, :, :, 1]
|
| 293 |
+
|
| 294 |
+
loss = criterion(
|
| 295 |
+
region_image_label,
|
| 296 |
+
affinity_image_label,
|
| 297 |
+
out1,
|
| 298 |
+
out2,
|
| 299 |
+
confidence_mask_label,
|
| 300 |
+
self.config.train.neg_rto,
|
| 301 |
+
self.config.train.n_min_neg,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
optimizer.zero_grad()
|
| 305 |
+
scaler.scale(loss).backward()
|
| 306 |
+
scaler.step(optimizer)
|
| 307 |
+
scaler.update()
|
| 308 |
+
|
| 309 |
+
else:
|
| 310 |
+
output, _ = craft(images)
|
| 311 |
+
out1 = output[:, :, :, 0]
|
| 312 |
+
out2 = output[:, :, :, 1]
|
| 313 |
+
loss = criterion(
|
| 314 |
+
region_image_label,
|
| 315 |
+
affinity_image_label,
|
| 316 |
+
out1,
|
| 317 |
+
out2,
|
| 318 |
+
confidence_mask_label,
|
| 319 |
+
self.config.train.neg_rto,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
optimizer.zero_grad()
|
| 323 |
+
loss.backward()
|
| 324 |
+
optimizer.step()
|
| 325 |
+
|
| 326 |
+
end_time = time.time()
|
| 327 |
+
loss_value += loss.item()
|
| 328 |
+
batch_time += end_time - start_time
|
| 329 |
+
|
| 330 |
+
if train_step > 0 and train_step % 5 == 0:
|
| 331 |
+
mean_loss = loss_value / 5
|
| 332 |
+
loss_value = 0
|
| 333 |
+
avg_batch_time = batch_time / 5
|
| 334 |
+
batch_time = 0
|
| 335 |
+
|
| 336 |
+
print(
|
| 337 |
+
"{}, training_step: {}|{}, learning rate: {:.8f}, "
|
| 338 |
+
"training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
|
| 339 |
+
time.strftime(
|
| 340 |
+
"%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
|
| 341 |
+
),
|
| 342 |
+
train_step,
|
| 343 |
+
whole_training_step,
|
| 344 |
+
training_lr,
|
| 345 |
+
mean_loss,
|
| 346 |
+
avg_batch_time,
|
| 347 |
+
)
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
if self.config.wandb_opt:
|
| 351 |
+
wandb.log({"train_step": train_step, "mean_loss": mean_loss})
|
| 352 |
+
|
| 353 |
+
if (
|
| 354 |
+
train_step % self.config.train.eval_interval == 0
|
| 355 |
+
and train_step != 0
|
| 356 |
+
):
|
| 357 |
+
|
| 358 |
+
craft.eval()
|
| 359 |
+
|
| 360 |
+
print("Saving state, index:", train_step)
|
| 361 |
+
save_param_dic = {
|
| 362 |
+
"iter": train_step,
|
| 363 |
+
"craft": craft.state_dict(),
|
| 364 |
+
"optimizer": optimizer.state_dict(),
|
| 365 |
+
}
|
| 366 |
+
save_param_path = (
|
| 367 |
+
self.config.results_dir
|
| 368 |
+
+ "/CRAFT_clr_"
|
| 369 |
+
+ repr(train_step)
|
| 370 |
+
+ ".pth"
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
if self.config.train.amp:
|
| 374 |
+
save_param_dic["scaler"] = scaler.state_dict()
|
| 375 |
+
save_param_path = (
|
| 376 |
+
self.config.results_dir
|
| 377 |
+
+ "/CRAFT_clr_amp_"
|
| 378 |
+
+ repr(train_step)
|
| 379 |
+
+ ".pth"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
torch.save(save_param_dic, save_param_path)
|
| 383 |
+
|
| 384 |
+
# validation
|
| 385 |
+
self.iou_eval(
|
| 386 |
+
"custom_data",
|
| 387 |
+
train_step,
|
| 388 |
+
buffer_dict["custom_data"],
|
| 389 |
+
craft,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
train_step += 1
|
| 393 |
+
if train_step >= whole_training_step:
|
| 394 |
+
break
|
| 395 |
+
|
| 396 |
+
if self.config.mode == "weak_supervision":
|
| 397 |
+
state_dict = craft.module.state_dict()
|
| 398 |
+
supervision_model.load_state_dict(state_dict)
|
| 399 |
+
trn_real_dataset.update_model(supervision_model)
|
| 400 |
+
|
| 401 |
+
# save last model
|
| 402 |
+
save_param_dic = {
|
| 403 |
+
"iter": train_step,
|
| 404 |
+
"craft": craft.state_dict(),
|
| 405 |
+
"optimizer": optimizer.state_dict(),
|
| 406 |
+
}
|
| 407 |
+
save_param_path = (
|
| 408 |
+
self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
if self.config.train.amp:
|
| 412 |
+
save_param_dic["scaler"] = scaler.state_dict()
|
| 413 |
+
save_param_path = (
|
| 414 |
+
self.config.results_dir
|
| 415 |
+
+ "/CRAFT_clr_amp_"
|
| 416 |
+
+ repr(train_step)
|
| 417 |
+
+ ".pth"
|
| 418 |
+
)
|
| 419 |
+
torch.save(save_param_dic, save_param_path)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def main():
    """CLI entry point for custom-data CRAFT fine-tuning.

    Parses arguments, loads the YAML configuration, prepares the results
    directory (archiving a copy of the config there), optionally initializes
    wandb, and runs a single-GPU ``Trainer``.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # Load configuration; the yaml name doubles as the experiment name.
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir (exist_ok avoids a TOCTOU race vs. os.path.exists).
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir so the run is reproducible.
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    # Weak supervision is the only mode with special handling downstream.
    mode = "weak_supervision" if config["mode"] == "weak_supervision" else None

    # Apply config to wandb
    if config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    config = DotDict(config)

    # Start train on GPU 0; evaluation buffer starts empty.
    buffer_dict = {"custom_data": None}
    trainer = Trainer(config, 0, mode)
    trainer.train(buffer_dict)

    if config["wandb_opt"]:
        wandb.finish()
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# Script entry point: run the training CLI when executed directly.
if __name__ == "__main__":
    main()
|
trainer/craft/trainSynth.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
import yaml
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
import wandb
|
| 14 |
+
|
| 15 |
+
from config.load_config import load_yaml, DotDict
|
| 16 |
+
from data.dataset import SynthTextDataSet
|
| 17 |
+
from loss.mseloss import Maploss_v2, Maploss_v3
|
| 18 |
+
from model.craft import CRAFT
|
| 19 |
+
from metrics.eval_det_iou import DetectionIoUEvaluator
|
| 20 |
+
from eval import main_eval
|
| 21 |
+
from utils.util import copyStateDict, save_parser
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Trainer(object):
    """Distributed (DDP) trainer for CRAFT stage-1 pretraining on SynthText.

    One instance runs per spawned process/GPU; the process group must already
    be initialized (see ``main_worker``). Rank 0 handles logging,
    checkpointing and evaluation bookkeeping.
    """

    def __init__(self, config, gpu):
        # config: DotDict-wrapped YAML configuration (batch size already
        #         divided per process by main_worker).
        # gpu: this process's rank / CUDA device index.
        self.config = config
        self.gpu = gpu
        # No weak supervision in synthetic pretraining.
        self.mode = None
        self.trn_loader, self.trn_sampler = self.get_trn_loader()
        # Checkpoint payload (dict with "craft"/"optimizer"/"scaler") or None.
        self.net_param = self.get_load_param(gpu)

    def get_trn_loader(self):
        """Build the SynthText DataLoader with a DistributedSampler.

        Returns (loader, sampler); the sampler is kept so the epoch seed can
        be reset each pass in train().
        """
        dataset = SynthTextDataSet(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.data_dir.synthtext,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            aug=self.config.train.data.syn_aug,
            vis_test_dir=self.config.vis_test_dir,
            vis_opt=self.config.train.data.vis_opt,
            sample=self.config.train.data.syn_sample,
        )

        trn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

        # shuffle must stay False: the DistributedSampler owns shuffling.
        trn_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.train.batch_size,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            sampler=trn_sampler,
            drop_last=True,
            pin_memory=True,
        )
        return trn_loader, trn_sampler

    def get_load_param(self, gpu):
        """Load the configured checkpoint, remapping cuda:0 tensors to this GPU."""
        if self.config.train.ckpt_path is not None:
            map_location = {"cuda:%d" % 0: "cuda:%d" % gpu}
            param = torch.load(self.config.train.ckpt_path, map_location=map_location)
        else:
            param = None
        return param

    def adjust_learning_rate(self, optimizer, gamma, step, lr):
        """Apply step decay: lr * gamma**step to every param group.

        Returns the new learning rate (all groups receive the same value).
        """
        lr = lr * (gamma ** step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        return param_group["lr"]

    def get_loss(self):
        """Select the map loss variant from config.train.loss (2 or 3)."""
        if self.config.train.loss == 2:
            criterion = Maploss_v2()
        elif self.config.train.loss == 3:
            criterion = Maploss_v3()
        else:
            raise Exception("Undefined loss")
        return criterion

    def iou_eval(self, dataset, train_step, save_param_path, buffer, model):
        """Run detection-IoU evaluation and log metrics to wandb (rank 0 only).

        Unlike the fine-tuning trainer, the saved checkpoint path is passed
        to main_eval alongside the in-memory model.
        """
        test_config = DotDict(self.config.test[dataset])

        val_result_dir = os.path.join(
            self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
        )

        evaluator = DetectionIoUEvaluator()

        metrics = main_eval(
            save_param_path,
            self.config.train.backbone,
            test_config,
            evaluator,
            val_result_dir,
            buffer,
            model,
            self.mode,
        )
        if self.gpu == 0 and self.config.wandb_opt:
            wandb.log(
                {
                    "{} IoU Recall".format(dataset): np.round(metrics["recall"], 3),
                    "{} IoU Precision".format(dataset): np.round(
                        metrics["precision"], 3
                    ),
                    "{} IoU F1-score".format(dataset): np.round(metrics["hmean"], 3),
                }
            )

    def train(self, buffer_dict):
        """Main DDP training loop over SynthText.

        ``buffer_dict`` holds multiprocessing-shared lists used to collect
        per-image evaluation results across ranks.
        """
        torch.cuda.set_device(self.gpu)

        # DATASET -----------------------------------------------------------------------------------------------------#
        trn_loader = self.trn_loader

        # MODEL -------------------------------------------------------------------------------------------------------#
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=True, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))
        # Sync BN statistics across ranks, then wrap for DDP on this device.
        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # OPTIMIZER----------------------------------------------------------------------------------------------------#

        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        # Resume optimizer state and derive the resume step/lr from it.
        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
            self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            # load model
            if (
                self.config.train.ckpt_path is not None
                and self.config.train.st_iter != 0
            ):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        epoch = 0
        start_time = time.time()

        while train_step < whole_training_step:
            # NOTE(review): set_epoch is seeded with train_step rather than
            # the `epoch` counter below — confirm this shuffling behavior is
            # intended.
            self.trn_sampler.set_epoch(train_step)
            for (
                index,
                (image, region_image, affinity_image, confidence_mask,),
            ) in enumerate(trn_loader):
                craft.train()
                # Step-based learning-rate decay.
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = image.cuda(non_blocking=True)
                region_image_label = region_image.cuda(non_blocking=True)
                affinity_image_label = affinity_image.cuda(non_blocking=True)
                confidence_mask_label = confidence_mask.cuda(non_blocking=True)

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        output, _ = craft(images)
                        # Channel 0: region score map, channel 1: affinity map.
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    # NOTE(review): unlike the AMP path, n_min_neg is not
                    # passed here — verify the criterion's default matches
                    # the configured value.
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time

                # Console logging every 5 steps, rank 0 only.
                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    mean_loss = loss_value / 5
                    loss_value = 0
                    avg_batch_time = batch_time / 5
                    batch_time = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime(
                                "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                            ),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        )
                    )
                    if self.gpu == 0 and self.config.wandb_opt:
                        wandb.log({"train_step": train_step, "mean_loss": mean_loss})

                # Periodic checkpoint + IoU evaluation.
                if (
                    train_step % self.config.train.eval_interval == 0
                    and train_step != 0
                ):

                    # Reset all shared eval buffer slots before re-evaluation
                    # (slots are cleared to None, not zero).
                    if self.gpu == 0:
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                    print("Saving state, index:", train_step)
                    save_param_dic = {
                        "iter": train_step,
                        "craft": craft.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_"
                        + repr(train_step)
                        + ".pth"
                    )

                    # AMP checkpoints also carry the GradScaler state and use
                    # a distinct filename prefix.
                    if self.config.train.amp:
                        save_param_dic["scaler"] = scaler.state_dict()
                        save_param_path = (
                            self.config.results_dir
                            + "/CRAFT_clr_amp_"
                            + repr(train_step)
                            + ".pth"
                        )

                    if self.gpu == 0:
                        torch.save(save_param_dic, save_param_path)

                    # validation
                    self.iou_eval(
                        "icdar2013",
                        train_step,
                        save_param_path,
                        buffer_dict["icdar2013"],
                        craft,
                    )

                train_step += 1
                if train_step >= whole_training_step:
                    break
            epoch += 1

        # save last model (rank 0 only)
        if self.gpu == 0:
            save_param_dic = {
                "iter": train_step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_param_path = (
                self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
            )

            if self.config.train.amp:
                save_param_dic["scaler"] = scaler.state_dict()
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_amp_"
                    + repr(train_step)
                    + ".pth"
                )
            torch.save(save_param_dic, save_param_path)
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def main():
    """CLI entry point for distributed SynthText pretraining.

    Parses arguments, loads the YAML configuration, prepares the results
    directory (archiving a copy of the config there), allocates shared
    evaluation buffers, and spawns one ``main_worker`` process per GPU.
    """
    parser = argparse.ArgumentParser(description="CRAFT SynthText Train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="syn_train",
        type=str,
        help="Load configuration",
    )
    # Help text fixed: this argument is the DDP rendezvous port, not a config.
    parser.add_argument(
        "--port", "--use ddp port", default="2646", type=str, help="Port number"
    )

    args = parser.parse_args()

    # Load configuration; the yaml name doubles as the experiment name.
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir (exist_ok avoids a TOCTOU race vs. os.path.exists).
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    os.makedirs(res_dir, exist_ok=True)

    # Duplicate yaml file to result_dir so the run is reproducible.
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    ngpus_per_node = torch.cuda.device_count()
    print(f"Total device num : {ngpus_per_node}")

    # Shared (cross-process) buffer: one slot per evaluation image.
    manager = mp.Manager()
    buffer1 = manager.list([None] * config["test"]["icdar2013"]["test_set_size"])
    buffer_dict = {"icdar2013": buffer1}
    torch.multiprocessing.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(args.port, ngpus_per_node, config, buffer_dict, exp_name,),
    )
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name):
    """Per-process DDP entry point, spawned once per GPU.

    Args:
        gpu: local rank of this process; also used as the DDP rank.
        port: TCP port for the rendezvous (string).
        ngpus_per_node: world size.
        config: plain-dict configuration (converted to DotDict below).
        buffer_dict: manager-backed shared buffers for evaluation results.
        exp_name: experiment name used for the wandb run.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:" + port,
        world_size=ngpus_per_node,
        rank=gpu,
    )

    # Apply config to wandb (rank 0 only, and only when enabled in config).
    if gpu == 0 and config["wandb_opt"]:
        wandb.init(project="craft-stage1", entity="gmuffiness", name=exp_name)
        wandb.config.update(config)

    # Split the global batch size evenly across the worker processes.
    batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
    config["train"]["batch_size"] = batch_size
    config = DotDict(config)

    # Start train
    trainer = Trainer(config, gpu)
    trainer.train(buffer_dict)

    if gpu == 0 and config["wandb_opt"]:
        wandb.finish()
    torch.distributed.destroy_process_group()
|
| 407 |
+
|
| 408 |
+
if __name__ == "__main__":
|
| 409 |
+
main()
|
trainer/craft/train_distributed.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
import multiprocessing as mp
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
import wandb
|
| 14 |
+
|
| 15 |
+
from config.load_config import load_yaml, DotDict
|
| 16 |
+
from data.dataset import SynthTextDataSet, CustomDataset
|
| 17 |
+
from loss.mseloss import Maploss_v2, Maploss_v3
|
| 18 |
+
from model.craft import CRAFT
|
| 19 |
+
from eval import main_eval
|
| 20 |
+
from metrics.eval_det_iou import DetectionIoUEvaluator
|
| 21 |
+
from utils.util import copyStateDict, save_parser
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Trainer(object):
    """Distributed (DDP) fine-tuning driver for the CRAFT detector.

    One instance runs per GPU process. Supports an optional
    "weak_supervision" mode in which a frozen copy of the model, placed on
    a second half of the GPUs, generates pseudo character boxes for the
    custom dataset.
    """

    def __init__(self, config, gpu, mode):
        # config: DotDict configuration; gpu: local rank of this process;
        # mode: "weak_supervision" or None (see main()).
        self.config = config
        self.gpu = gpu
        self.mode = mode
        # Checkpoint contents (or None when training from scratch).
        self.net_param = self.get_load_param(gpu)

    def get_synth_loader(self):
        """Build a DistributedSampler-backed DataLoader over SynthText."""
        dataset = SynthTextDataSet(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.train.synth_data_dir,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            aug=self.config.train.data.syn_aug,
            vis_test_dir=self.config.vis_test_dir,
            vis_opt=self.config.train.data.vis_opt,
            sample=self.config.train.data.syn_sample,
        )

        syn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)

        # Synth images only fill 1/synth_ratio of each batch; the rest is
        # real data (see train()).
        syn_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            sampler=syn_sampler,
            drop_last=True,
            pin_memory=True,
        )
        return syn_loader

    def get_custom_dataset(self):
        """Build the user-provided (real-image) training dataset."""
        custom_dataset = CustomDataset(
            output_size=self.config.train.data.output_size,
            data_dir=self.config.data_root_dir,
            saved_gt_dir=None,
            mean=self.config.train.data.mean,
            variance=self.config.train.data.variance,
            gauss_init_size=self.config.train.data.gauss_init_size,
            gauss_sigma=self.config.train.data.gauss_sigma,
            enlarge_region=self.config.train.data.enlarge_region,
            enlarge_affinity=self.config.train.data.enlarge_affinity,
            watershed_param=self.config.train.data.watershed,
            aug=self.config.train.data.custom_aug,
            vis_test_dir=self.config.vis_test_dir,
            sample=self.config.train.data.custom_sample,
            vis_opt=self.config.train.data.vis_opt,
            pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
            do_not_care_label=self.config.train.data.do_not_care_label,
        )

        return custom_dataset

    def get_load_param(self, gpu):
        """Load the checkpoint (if configured) directly onto `gpu`.

        Returns the raw checkpoint dict, or None when no ckpt_path is set.
        """
        if self.config.train.ckpt_path is not None:
            map_location = "cuda:%d" % gpu
            param = torch.load(self.config.train.ckpt_path, map_location=map_location)
        else:
            param = None

        return param

    def adjust_learning_rate(self, optimizer, gamma, step, lr):
        """Exponential step decay: lr * gamma**step, applied to all groups."""
        lr = lr * (gamma ** step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        return param_group["lr"]

    def get_loss(self):
        """Select the map loss variant from config (2 or 3)."""
        if self.config.train.loss == 2:
            criterion = Maploss_v2()
        elif self.config.train.loss == 3:
            criterion = Maploss_v3()
        else:
            raise Exception("Undefined loss")
        return criterion

    def iou_eval(self, dataset, train_step, buffer, model):
        """Run detection-IoU evaluation on `dataset` and log to wandb (rank 0)."""
        test_config = DotDict(self.config.test[dataset])

        val_result_dir = os.path.join(
            self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
        )

        evaluator = DetectionIoUEvaluator()

        metrics = main_eval(
            None,
            self.config.train.backbone,
            test_config,
            evaluator,
            val_result_dir,
            buffer,
            model,
            self.mode,
        )
        if self.gpu == 0 and self.config.wandb_opt:
            wandb.log(
                {
                    "{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
                    "{} iou Precision".format(dataset): np.round(
                        metrics["precision"], 3
                    ),
                    "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
                }
            )

    def train(self, buffer_dict):
        """Main DDP training loop.

        Builds the model(s), datasets, optimizer and loss, then iterates
        until `train.end_iter`, periodically checkpointing and evaluating.
        `buffer_dict` holds the shared evaluation buffers keyed by dataset.
        """
        torch.cuda.set_device(self.gpu)
        total_gpu_num = torch.cuda.device_count()

        # MODEL -------------------------------------------------------------------------------------------------------#
        # SUPERVISION model: a second, frozen copy of CRAFT used to generate
        # pseudo labels when running in weak-supervision mode.
        if self.config.mode == "weak_supervision":
            if self.config.train.backbone == "vgg":
                supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
            else:
                raise Exception("Undefined architecture")

            # NOTE: only work on half GPU assign train / half GPU assign supervision setting
            supervision_device = total_gpu_num // 2 + self.gpu
            if self.config.train.ckpt_path is not None:
                supervision_param = self.get_load_param(supervision_device)
                supervision_model.load_state_dict(
                    copyStateDict(supervision_param["craft"])
                )
                supervision_model = supervision_model.to(f"cuda:{supervision_device}")
                print(f"Supervision model loading on : gpu {supervision_device}")
        else:
            supervision_model, supervision_device = None, None

        # TRAIN model
        if self.config.train.backbone == "vgg":
            craft = CRAFT(pretrained=False, amp=self.config.train.amp)
        else:
            raise Exception("Undefined architecture")

        if self.config.train.ckpt_path is not None:
            craft.load_state_dict(copyStateDict(self.net_param["craft"]))

        # Sync BN statistics across ranks, then wrap in DDP.
        craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
        craft = craft.cuda()
        craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])

        torch.backends.cudnn.benchmark = True

        # DATASET -----------------------------------------------------------------------------------------------------#
        if self.config.train.use_synthtext:
            trn_syn_loader = self.get_synth_loader()
            batch_syn = iter(trn_syn_loader)

        if self.config.train.real_dataset == "custom":
            trn_real_dataset = self.get_custom_dataset()
        else:
            raise Exception("Undefined dataset")

        if self.config.mode == "weak_supervision":
            # The dataset queries the supervision model to create pseudo labels.
            trn_real_dataset.update_model(supervision_model)
            trn_real_dataset.update_device(supervision_device)

        trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
            trn_real_dataset
        )
        trn_real_loader = torch.utils.data.DataLoader(
            trn_real_dataset,
            batch_size=self.config.train.batch_size,
            shuffle=False,
            num_workers=self.config.train.num_workers,
            sampler=trn_real_sampler,
            drop_last=False,
            pin_memory=True,
        )

        # OPTIMIZER ---------------------------------------------------------------------------------------------------#
        optimizer = optim.Adam(
            craft.parameters(),
            lr=self.config.train.lr,
            weight_decay=self.config.train.weight_decay,
        )

        # When resuming mid-run, restore optimizer state and pick up the
        # iteration counter / lr from the checkpoint.
        if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
            optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
            self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
            self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]

        # LOSS --------------------------------------------------------------------------------------------------------#
        # mixed precision
        if self.config.train.amp:
            scaler = torch.cuda.amp.GradScaler()

            if (
                self.config.train.ckpt_path is not None
                and self.config.train.st_iter != 0
            ):
                scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
        else:
            scaler = None

        criterion = self.get_loss()

        # TRAIN -------------------------------------------------------------------------------------------------------#
        train_step = self.config.train.st_iter
        whole_training_step = self.config.train.end_iter
        update_lr_rate_step = 0
        training_lr = self.config.train.lr
        loss_value = 0
        batch_time = 0
        start_time = time.time()

        print(
            "================================ Train start ================================"
        )
        while train_step < whole_training_step:
            # Reshuffle the sampler each pass; train_step stands in for epoch.
            trn_real_sampler.set_epoch(train_step)
            for (
                index,
                (
                    images,
                    region_scores,
                    affinity_scores,
                    confidence_masks,
                ),
            ) in enumerate(trn_real_loader):
                craft.train()
                if train_step > 0 and train_step % self.config.train.lr_decay == 0:
                    update_lr_rate_step += 1
                    training_lr = self.adjust_learning_rate(
                        optimizer,
                        self.config.train.gamma,
                        update_lr_rate_step,
                        self.config.train.lr,
                    )

                images = images.cuda(non_blocking=True)
                region_scores = region_scores.cuda(non_blocking=True)
                affinity_scores = affinity_scores.cuda(non_blocking=True)
                confidence_masks = confidence_masks.cuda(non_blocking=True)

                if self.config.train.use_synthtext:
                    # Synth image load
                    syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
                        batch_syn
                    )
                    syn_image = syn_image.cuda(non_blocking=True)
                    syn_region_label = syn_region_label.cuda(non_blocking=True)
                    syn_affi_label = syn_affi_label.cuda(non_blocking=True)
                    syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)

                    # concat syn & custom image along the batch dimension
                    images = torch.cat((syn_image, images), 0)
                    region_image_label = torch.cat(
                        (syn_region_label, region_scores), 0
                    )
                    affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
                    confidence_mask_label = torch.cat(
                        (syn_confidence_mask, confidence_masks), 0
                    )
                else:
                    region_image_label = region_scores
                    affinity_image_label = affinity_scores
                    confidence_mask_label = confidence_masks

                if self.config.train.amp:
                    with torch.cuda.amp.autocast():

                        # Channel 0 = region score map, channel 1 = affinity map.
                        output, _ = craft(images)
                        out1 = output[:, :, :, 0]
                        out2 = output[:, :, :, 1]

                        loss = criterion(
                            region_image_label,
                            affinity_image_label,
                            out1,
                            out2,
                            confidence_mask_label,
                            self.config.train.neg_rto,
                            self.config.train.n_min_neg,
                        )

                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                else:
                    output, _ = craft(images)
                    out1 = output[:, :, :, 0]
                    out2 = output[:, :, :, 1]
                    # NOTE(review): unlike the AMP path, this call does not
                    # pass n_min_neg — confirm the criterion's default matches.
                    loss = criterion(
                        region_image_label,
                        affinity_image_label,
                        out1,
                        out2,
                        confidence_mask_label,
                        self.config.train.neg_rto,
                    )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                end_time = time.time()
                loss_value += loss.item()
                batch_time += end_time - start_time

                # Report a 5-step moving average from rank 0.
                if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
                    mean_loss = loss_value / 5
                    loss_value = 0
                    avg_batch_time = batch_time / 5
                    batch_time = 0

                    print(
                        "{}, training_step: {}|{}, learning rate: {:.8f}, "
                        "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                            time.strftime(
                                "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
                            ),
                            train_step,
                            whole_training_step,
                            training_lr,
                            mean_loss,
                            avg_batch_time,
                        )
                    )

                    if self.gpu == 0 and self.config.wandb_opt:
                        wandb.log({"train_step": train_step, "mean_loss": mean_loss})

                # Periodic checkpoint + IoU validation.
                if (
                    train_step % self.config.train.eval_interval == 0
                    and train_step != 0
                ):

                    craft.eval()
                    # initialize all buffer value with zero
                    if self.gpu == 0:
                        for buffer in buffer_dict.values():
                            for i in range(len(buffer)):
                                buffer[i] = None

                    print("Saving state, index:", train_step)
                    save_param_dic = {
                        "iter": train_step,
                        "craft": craft.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    save_param_path = (
                        self.config.results_dir
                        + "/CRAFT_clr_"
                        + repr(train_step)
                        + ".pth"
                    )

                    # AMP runs also persist the GradScaler state (and use a
                    # distinct filename prefix).
                    if self.config.train.amp:
                        save_param_dic["scaler"] = scaler.state_dict()
                        save_param_path = (
                            self.config.results_dir
                            + "/CRAFT_clr_amp_"
                            + repr(train_step)
                            + ".pth"
                        )

                    torch.save(save_param_dic, save_param_path)

                    # validation
                    self.iou_eval(
                        "custom_data",
                        train_step,
                        buffer_dict["custom_data"],
                        craft,
                    )

                train_step += 1
                if train_step >= whole_training_step:
                    break

            # Refresh the supervision model with the latest trained weights
            # so pseudo labels improve as training progresses.
            if self.config.mode == "weak_supervision":
                state_dict = craft.module.state_dict()
                supervision_model.load_state_dict(state_dict)
                trn_real_dataset.update_model(supervision_model)

        # save last model
        if self.gpu == 0:
            save_param_dic = {
                "iter": train_step,
                "craft": craft.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_param_path = (
                self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
            )

            if self.config.train.amp:
                save_param_dic["scaler"] = scaler.state_dict()
                save_param_path = (
                    self.config.results_dir
                    + "/CRAFT_clr_amp_"
                    + repr(train_step)
                    + ".pth"
                )
            torch.save(save_param_dic, save_param_path)
|
| 437 |
+
|
| 438 |
+
def main():
    """CLI entry point for CRAFT fine-tuning on custom data.

    Parses arguments, loads the YAML config, prepares the results
    directory, decides the number of worker processes (halved in
    weak-supervision mode, where the other half of the GPUs hosts the
    supervision models), then spawns the DDP workers.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    # NOTE(review): the alias "--use ddp port" contains spaces and cannot be
    # typed on a command line; only "--port" is actually usable.
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # Load configuration; the yaml name doubles as the experiment name.
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Make result_dir (one sub-directory per experiment).
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    # Duplicate yaml file to result_dir so the run can be reproduced later.
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    if config["mode"] == "weak_supervision":
        # NOTE: half GPU assign train / half GPU assign supervision setting
        ngpus_per_node = torch.cuda.device_count() // 2
        mode = "weak_supervision"
    else:
        ngpus_per_node = torch.cuda.device_count()
        mode = None

    print(f"Total process num : {ngpus_per_node}")

    # Manager-backed list shared across worker processes; each eval slot is
    # filled by whichever rank processes that test image.
    manager = mp.Manager()
    buffer1 = manager.list([None] * config["test"]["custom_data"]["test_set_size"])

    buffer_dict = {"custom_data": buffer1}
    torch.multiprocessing.spawn(
        main_worker,
        nprocs=ngpus_per_node,
        args=(args.port, ngpus_per_node, config, buffer_dict, exp_name, mode,),
    )
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name, mode):
    """Per-process DDP entry point, spawned once per GPU.

    Args:
        gpu: local rank of this process; also used as the DDP rank.
        port: TCP port for the rendezvous (string).
        ngpus_per_node: world size.
        config: plain-dict configuration (converted to DotDict below).
        buffer_dict: manager-backed shared buffers for evaluation results.
        exp_name: experiment name used for the wandb run.
        mode: "weak_supervision" or None.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:" + port,
        world_size=ngpus_per_node,
        rank=gpu,
    )

    # Apply config to wandb (rank 0 only, and only when enabled in config).
    if gpu == 0 and config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    # Split the global batch size evenly across the worker processes.
    batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
    config["train"]["batch_size"] = batch_size
    config = DotDict(config)

    # Start train
    trainer = Trainer(config, gpu, mode)
    trainer.train(buffer_dict)

    if gpu == 0:
        if config["wandb_opt"]:
            wandb.finish()

    # All ranks synchronize before tearing the process group down.
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
| 521 |
+
|
| 522 |
+
if __name__ == "__main__":
|
| 523 |
+
main()
|
trainer/craft/utils/craft_utils.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import cv2
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
from data import imgproc
|
| 9 |
+
|
| 10 |
+
""" auxilary functions """
|
| 11 |
+
# unwarp corodinates
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def warpCoord(Minv, pt):
    """Map a 2-D point through the 3x3 homography *Minv*.

    The point is lifted to homogeneous coordinates (x, y, 1), multiplied by
    the matrix, then de-homogenized by the resulting w component.
    """
    homogeneous = np.matmul(Minv, (pt[0], pt[1], 1))
    scale = homogeneous[2]
    return np.array([homogeneous[0] / scale, homogeneous[1] / scale])
|
| 19 |
+
""" end of auxilary functions """
|
| 20 |
+
|
| 21 |
+
def test():
    """Smoke-test hook: emit a fixed marker to stdout."""
    marker = 'pass'
    print(marker)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    """Extract word-level quadrilateral boxes from CRAFT score maps.

    Thresholds the region (text) and affinity (link) maps, labels connected
    components of their union, and fits a rotated min-area rectangle to each
    surviving component.

    Args:
        textmap: 2-D region score map (float, same H x W as linkmap).
        linkmap: 2-D affinity score map.
        text_threshold: a component is kept only if its peak region score
            reaches this value.
        link_threshold: binarization threshold for the link map.
        low_text: binarization threshold for the text map.

    Returns:
        (det, labels, mapper): list of 4x2 boxes in clockwise order starting
        from the top-left-most corner, the connected-component label image,
        and the component id backing each box.
    """
    # prepare data (work on copies; the maps are binarized in place below)
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    """ labeling method """
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    # Union of text and link pixels; links glue characters into words.
    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = \
        cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    for k in range(1,nLabels):
        # size filtering: drop tiny components (< 10 px)
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        # thresholding: require a confident peak inside the component
        if np.max(textmap[labels==k]) < text_threshold: continue

        # make segmentation map for this component only
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        # Dilation radius scales with component density (area vs bbox).
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check (clamp the dilation window to the image)
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1)
        #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5))
        #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1)

        # make box: min-area rotated rectangle over the component pixels
        # (np.where gives (row, col); np.roll swaps to (x, y))
        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape: near-square boxes fall back to the axis-aligned
        # bounding box to avoid arbitrary rotation
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order, starting from the corner closest to origin
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper
|
| 90 |
+
|
| 91 |
+
def getPoly_core(boxes, labels, mapper, linkmap):
    """Refine detected quadrilaterals into polygons that follow (possibly curved) text.

    For each box, the corresponding connected-component region is warped to an
    axis-aligned rectangle, scanned column-by-column to find the text spine,
    and a polygon is fitted through pivot points along that spine.

    :param boxes: list of 4x2 float32 quadrilaterals from getDetBoxes_core
    :param labels: connected-component label map (same frame as the score maps)
    :param mapper: list mapping each box index to its component label id
    :param linkmap: affinity score map (unused here; kept for interface parity)
    :return: list aligned with ``boxes``; each entry is an Nx2 point array,
             or None when no reliable polygon could be derived for that box.
    """
    # configs
    num_cp = 5            # number of pivot points along the text spine
    max_len_ratio = 0.7   # reject if any column span is close to the box height
    expand_ratio = 1.45   # vertical expansion applied to the median char height
    max_r = 2.0           # max search radius (in char heights) for edge points
    step_r = 0.2          # search step for edge points

    polys = []
    for k, box in enumerate(boxes):
        # size filter for small instance: polygons are only fitted to boxes
        # at least 30px on each side (in warped coordinates)
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 30 or h < 30:
            polys.append(None); continue

        # warp the label map so the word occupies an upright w x h rectangle
        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except:
            # singular transform: cannot map polygon points back to the image
            polys.append(None); continue

        # binarization for selected label: keep only this box's component
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        """ Polygon generation """
        # find top/bottom contours: for each column record (x, top_y, bottom_y)
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length

        # pass if max_len is similar to h (text fills the box vertically,
        # so the quadrilateral is already a good fit)
        if h * max_len_ratio < max_len:
            polys.append(None); continue

        # get pivot points with fixed length: split the width into
        # 2*num_cp+1 segments; odd segments contribute pivot points
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg # segment width
        pp = [None] * num_cp # init pivot points
        # NOTE: [[0, 0]] * tot_seg aliases one inner list, but every update
        # below rebinds cp_section[i] to a fresh list, so no shared mutation.
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0,len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points (column midpoints) for this segment
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0: continue # No polygon area

            # within an odd segment, keep the tallest column as the pivot
            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment widh is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradiant and apply to make horizontal pivots: turn each pivot
        # into a short vertical segment perpendicular to the local spine
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0: # gradient if zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps: extrapolate the first
        # and last pivot segments outward until they clear the text mask
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                # accept when the probe line no longer intersects text,
                # or when the search radius is about to run out
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None); continue

        # make final polygon: walk top edge left->right, then bottom edge
        # right->left, mapping every point back to image coordinates
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys
|
| 236 |
+
|
| 237 |
+
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    """Detect word boxes from region/affinity score maps.

    :param textmap: 2D region (character) score map
    :param linkmap: 2D affinity (link) score map
    :param text_threshold: minimum peak region score for a component to count
    :param link_threshold: binarization threshold for the link map
    :param low_text: binarization threshold for the text map
    :param poly: when True, also fit refined polygons to each box
    :return: (boxes, polys) — polys holds None for boxes without a polygon
    """
    boxes, labels, mapper = getDetBoxes_core(
        textmap, linkmap, text_threshold, link_threshold, low_text
    )
    polys = getPoly_core(boxes, labels, mapper, linkmap) if poly else [None for _ in boxes]
    return boxes, polys
|
| 246 |
+
|
| 247 |
+
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Scale detection coordinates back to the original image frame.

    The network's score maps are downscaled by ``ratio_net`` relative to the
    (already resized) input, so each point is multiplied by
    (ratio_w * ratio_net, ratio_h * ratio_net).

    BUG FIX: the original called ``np.array(polys)`` on the whole list; with a
    mix of None entries and polygons of differing vertex counts that creates a
    ragged array, which raises ValueError on NumPy >= 1.24. It also silently
    mutated the caller's arrays in the object-array case. We now scale each
    entry into a fresh array and leave the inputs untouched.

    :param polys: sequence of Nx2 point arrays (None entries allowed)
    :param ratio_w: horizontal resize ratio (original / network input)
    :param ratio_h: vertical resize ratio
    :param ratio_net: score-map downscale factor of the network (default 2)
    :return: list of scaled point arrays, None entries preserved
    """
    scale = np.array([ratio_w * ratio_net, ratio_h * ratio_net])
    return [None if poly is None else poly * scale for poly in polys]
|
| 254 |
+
|
| 255 |
+
def save_outputs(image, region_scores, affinity_scores, text_threshold, link_threshold,
                 low_text, outoput_path, confidence_mask=None):
    """Save image, region_scores, and affinity_scores in a single image.

    region_scores and affinity_scores must be CPU numpy arrays. You can
    convert GPU tensors to CPU numpy arrays like this:
    >>> array = tensor.cpu().data.numpy()
    When saving outputs of the network during training, make sure you convert
    ALL tensors (image, region_score, affinity_score) to numpy arrays first.

    :param image: numpy array
    :param region_scores: 2D numpy array with each element between 0~1
    :param affinity_scores: same as region_scores
    :param text_threshold: 0~1; closer to 0, lower-confidence characters are boxed
    :param link_threshold: 0~1; closer to 0, lower-confidence links are boxed
    :param low_text: 0~1; closer to 0, boxes are drawn more loosely
    :param outoput_path: destination file path (name kept — [sic] — for callers)
    :param confidence_mask: optional confidence map to render alongside
    :return: the composited visualization image that was written to disk
    """
    assert region_scores.shape == affinity_scores.shape
    assert len(image.shape) - 1 == len(region_scores.shape)

    boxes, polys = getDetBoxes(region_scores, affinity_scores, text_threshold, link_threshold,
                               low_text, False)
    # score maps are half-resolution relative to the image, hence * 2
    boxes = np.array(boxes, np.int32) * 2
    if len(boxes) > 0:
        # BUG FIX: np.clip returns a new array and the originals discarded it,
        # so out-of-frame coordinates were never clamped. Clip in place.
        np.clip(boxes[:, :, 0], 0, image.shape[1], out=boxes[:, :, 0])
        np.clip(boxes[:, :, 1], 0, image.shape[0], out=boxes[:, :, 1])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)

    if confidence_mask is not None:
        # layout: image | [region+affinity heatmaps ; blank+confidence mask]
        confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
        gt_scores = np.hstack([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color])
        confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray), confidence_mask_gray])
        output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
        output = np.hstack([image, output])
    else:
        # layout: image | [region heatmap stacked over affinity heatmap]
        gt_scores = np.concatenate([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color], axis=0)
        output = np.hstack([image, gt_scores])

    cv2.imwrite(outoput_path, output)
    return output
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def save_outputs_from_tensors(images, region_scores, affinity_scores, text_threshold, link_threshold,
                              low_text, output_dir, image_names, confidence_mask=None):
    """Take images, region_scores, and affinity_scores as tensors (can be GPU)
    and write one composited visualization per batch element via save_outputs.

    :param images: 4D tensor or numpy array, batch-first
    :param region_scores: 3D tensor with values between 0 ~ 1
    :param affinity_scores: 3D tensor with values between 0 ~ 1
    :param text_threshold: forwarded to save_outputs
    :param link_threshold: forwarded to save_outputs
    :param low_text: forwarded to save_outputs
    :param output_dir: directory for output images; joined with each image base name
    :param image_names: per-image names; only the base name is used
    :param confidence_mask: forwarded to save_outputs
    :return: list of the composited output images
    """
    # BUG FIX: use isinstance (not type()==) and move to CPU before converting;
    # np.array() on a CUDA tensor raises, so the original broke on GPU input.
    if isinstance(images, torch.Tensor):
        images = images.cpu().data.numpy()

    region_scores = region_scores.cpu().data.numpy()
    affinity_scores = affinity_scores.cpu().data.numpy()

    batch_size = images.shape[0]
    assert batch_size == region_scores.shape[0] and batch_size == affinity_scores.shape[0] and batch_size == len(image_names), \
        "The first dimension (i.e. batch size) of images, region scores, and affinity scores must be equal"

    output_images = []

    for i in range(batch_size):
        image = images[i]
        region_score = region_scores[i]
        affinity_score = affinity_scores[i]

        image_name = os.path.basename(image_names[i])
        outoput_path = os.path.join(output_dir, image_name)

        output_image = save_outputs(image, region_score, affinity_score, text_threshold, link_threshold,
                                    low_text, outoput_path, confidence_mask=confidence_mask)
        output_images.append(output_image)

    return output_images
|
trainer/craft/utils/inference_boxes.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import time
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
|
| 11 |
+
from utils.craft_utils import getDetBoxes, adjustResultCoordinates
|
| 12 |
+
from data import imgproc
|
| 13 |
+
from data.dataset import SynthTextDataSet
|
| 14 |
+
import math
|
| 15 |
+
import xml.etree.ElementTree as elemTree
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
#-------------------------------------------------------------------------------------------------------------------#
|
| 19 |
+
def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate point (xp, yp) about center (xc, yc) by angle theta (radians).

    Uses the matrix [[cos, sin], [-sin, cos]] on the offset vector, then
    truncates the result to integers.

    :return: (x, y) integer coordinates of the rotated point
    """
    dx, dy = xp - xc, yp - yc
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    rotated_x = cos_t * dx + sin_t * dy
    rotated_y = cos_t * dy - sin_t * dx
    return int(xc + rotated_x), int(yc + rotated_y)
|
| 29 |
+
|
| 30 |
+
def addRotatedShape(cx, cy, w, h, angle):
    """Return the four corner points of a w x h box centered at (cx, cy),
    rotated by -angle via rotatePoint.

    :return: list of [x, y] corner pairs in order
             top-left, top-right, bottom-right, bottom-left
    """
    half_w = w / 2
    half_h = h / 2
    corners = (
        (cx - half_w, cy - half_h),  # top-left
        (cx + half_w, cy - half_h),  # top-right
        (cx + half_w, cy + half_h),  # bottom-right
        (cx - half_w, cy + half_h),  # bottom-left
    )
    return [list(rotatePoint(cx, cy, px, py, -angle)) for px, py in corners]
|
| 39 |
+
|
| 40 |
+
def xml_parsing(xml):
    """Parse a labelImg-style XML annotation file into box-info dicts.

    Supports both rotated boxes (``robndbox``: cx/cy/w/h/angle, converted to
    corner points via addRotatedShape) and axis-aligned boxes (``bndbox``:
    xmin/ymin/xmax/ymax). Objects named "dnc" are marked as ignored with
    text "###".

    BUG FIX: the original reused one ``annotation`` dict per <object> and
    appended it once per box tag, so an object containing both a robndbox and
    a bndbox produced two aliased entries whose coordinates were both
    overwritten by the second box. Each box now gets its own dict.

    :param xml: path to the XML annotation file
    :return: list of {"points": np.ndarray(4x2), "text": str, "ignore": bool}
    """
    tree = elemTree.parse(xml)

    annotations = []  # one {"name", "box_coodi"} entry per box
    for element in tree.iter(tag="object"):
        name = element.find("name").text

        for box_coord in element.iter(tag="robndbox"):
            cx = float(box_coord.find("cx").text)
            cy = float(box_coord.find("cy").text)
            w = float(box_coord.find("w").text)
            h = float(box_coord.find("h").text)
            angle = float(box_coord.find("angle").text)
            annotations.append(
                {"name": name, "box_coodi": addRotatedShape(cx, cy, w, h, angle)}
            )

        for box_coord in element.iter(tag="bndbox"):
            xmin = int(box_coord.find("xmin").text)
            ymin = int(box_coord.find("ymin").text)
            xmax = int(box_coord.find("xmax").text)
            ymax = int(box_coord.find("ymax").text)
            annotations.append(
                {
                    "name": name,
                    "box_coodi": [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]],
                }
            )

    bounds = []
    for annotation in annotations:
        box_info_dict = {"points": np.array(annotation["box_coodi"]), "text": None, "ignore": None}
        if annotation["name"] == "dnc":
            # "do not care" regions are kept but flagged for evaluation to skip
            box_info_dict["text"] = "###"
            box_info_dict["ignore"] = True
        else:
            box_info_dict["text"] = annotation["name"]
            box_info_dict["ignore"] = False
        bounds.append(box_info_dict)

    return bounds
|
| 98 |
+
|
| 99 |
+
#-------------------------------------------------------------------------------------------------------------------#
|
| 100 |
+
|
| 101 |
+
def load_prescription_gt(dataFolder):
    """Collect prescription images and their XML annotations from a folder tree.

    Walks ``dataFolder`` recursively, pairing every .jpg with its .xml by
    sorted order, and parses each annotation file with xml_parsing. An
    assertion verifies each pair shares the same base path.

    :param dataFolder: root directory to scan
    :return: (parsed annotation lists, sorted image paths)
    """
    total_img_path = []
    total_imgs_bboxes = []
    for root, _directories, files in os.walk(dataFolder):
        for fname in files:
            full_path = os.path.join(root, fname)
            if '.jpg' in fname:
                total_img_path.append(full_path)
            if '.xml' in fname:
                total_imgs_bboxes.append(full_path)

    total_imgs_parsing_bboxes = []
    for img_path, xml_path in zip(sorted(total_img_path), sorted(total_imgs_bboxes)):
        # sanity check: image and annotation must share a base path
        assert img_path.split(".jpg")[0] == xml_path.split(".xml")[0]
        total_imgs_parsing_bboxes.append(xml_parsing(xml_path))

    return total_imgs_parsing_bboxes, sorted(total_img_path)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# NOTE
|
| 130 |
+
def load_prescription_cleval_gt(dataFolder):
    """Load prescription ground truth in CLEval text format.

    Walks ``dataFolder`` recursively, pairing every .jpg with its
    ``*_label_cl.txt`` annotation by sorted order. Each annotation line holds
    8 comma-separated integers (four x,y corner coordinates).

    BUG FIX: the original used ``open(...).readlines()`` without closing the
    file, leaking one handle per annotation; reads now use ``with``.

    :param dataFolder: root directory to scan
    :return: (list of per-image box-dict lists, sorted image paths)
    """
    total_img_path = []
    total_gt_path = []
    for root, _directories, files in os.walk(dataFolder):
        for fname in files:
            if '.jpg' in fname:
                total_img_path.append(os.path.join(root, fname))
            if '_cl.txt' in fname:
                total_gt_path.append(os.path.join(root, fname))

    total_imgs_parsing_bboxes = []
    for img_path, gt_path in zip(sorted(total_img_path), sorted(total_gt_path)):
        # sanity check: image and annotation must share a base path
        assert img_path.split(".jpg")[0] == gt_path.split('_label_cl.txt')[0]

        with open(gt_path, encoding="utf-8") as f:
            lines = f.readlines()

        word_bboxes = []
        for line in lines:
            box_info_dict = {"points": None, "text": None, "ignore": None}
            # strip a possible UTF-8 BOM before splitting
            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")

            box_points = [int(box_info[i]) for i in range(8)]
            box_info_dict["points"] = np.array(box_points)

            word_bboxes.append(box_info_dict)
        total_imgs_parsing_bboxes.append(word_bboxes)

    return total_imgs_parsing_bboxes, sorted(total_img_path)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_synthtext_gt(data_folder):
    """Load word-level ground truth for (the first 100 images of) SynthText.

    Uses SynthTextDataSet to read image names, word boxes, and word strings,
    then normalizes the box array layout and splits the raw text blobs into
    individual words.

    BUG FIX: the original dropped into the ipdb debugger on a word/box count
    mismatch and used a bare ``except:`` around the transpose; it now raises
    ValueError and catches only the transpose's ValueError.

    :param data_folder: SynthText root (also used as the saved-gt directory)
    :return: (list of per-image box-dict lists, image paths)
    """
    synth_dataset = SynthTextDataSet(
        output_size=768, data_dir=data_folder, saved_gt_dir=data_folder, logging=False
    )
    img_names, img_bbox, img_words = synth_dataset.load_data(bbox="word")

    total_img_path = []
    total_imgs_bboxes = []
    for index in range(len(img_bbox[:100])):
        img_path = os.path.join(data_folder, img_names[index][0])
        total_img_path.append(img_path)
        try:
            wordbox = img_bbox[index].transpose((2, 1, 0))
        except ValueError:
            # single-word image: the box array is missing the leading axis
            wordbox = np.expand_dims(img_bbox[index], axis=0)
            wordbox = wordbox.transpose((0, 2, 1))

        # split the raw text blobs on newlines/spaces and drop empties
        words = [re.split(" \n|\n |\n| ", t.strip()) for t in img_words[index]]
        words = list(itertools.chain(*words))
        words = [t for t in words if len(t) > 0]

        if len(words) != len(wordbox):
            raise ValueError(
                "word/box count mismatch for %s: %d words vs %d boxes"
                % (img_path, len(words), len(wordbox))
            )

        single_img_bboxes = []
        for j in range(len(words)):
            box_info_dict = {"points": None, "text": None, "ignore": None}
            box_info_dict["points"] = wordbox[j]
            box_info_dict["text"] = words[j]
            box_info_dict["ignore"] = False
            single_img_bboxes.append(box_info_dict)

        total_imgs_bboxes.append(single_img_bboxes)

    return total_imgs_bboxes, total_img_path
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def load_icdar2015_gt(dataFolder, isTraing=False):
    """Load ICDAR 2015 ground truth (8-point quadrilaterals + transcription).

    Each gt file line is "x1,y1,...,x4,y4,word"; words equal to "###" are
    flagged as ignored. Boxes are also drawn onto the image read from disk
    (preserved from the original; the drawn image is not returned).

    BUG FIX: ``np.int`` was removed in NumPy 1.24 and crashed the polylines
    call — replaced with ``np.int32``. The gt file is now closed via ``with``
    instead of being leaked.

    :param dataFolder: ICDAR2015 root containing the image/gt folders
    :param isTraing: True for the training split, False for the test split
    :return: (list of per-image box-dict lists, image paths)
    """
    if isTraing:
        img_folderName = "ch4_training_images"
        gt_folderName = "ch4_training_localization_transcription_gt"
    else:
        img_folderName = "ch4_test_images"
        gt_folderName = "ch4_test_localization_transcription_gt"

    gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName))
    total_imgs_bboxes = []
    total_img_path = []
    for gt_name in gt_folder_path:
        gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_name)
        # gt_xxx.txt in the gt folder maps to xxx.jpg in the image folder
        img_path = (
            gt_path.replace(gt_folderName, img_folderName)
            .replace(".txt", ".jpg")
            .replace("gt_", "")
        )
        image = cv2.imread(img_path)
        with open(gt_path, encoding="utf-8") as f:
            lines = f.readlines()
        single_img_bboxes = []
        for line in lines:
            box_info_dict = {"points": None, "text": None, "ignore": None}

            # strip a possible UTF-8 BOM before splitting
            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box_points = [int(box_info[j]) for j in range(8)]
            word = ",".join(box_info[8:])  # transcription may itself contain commas
            box_points = np.array(box_points, np.int32).reshape(4, 2)
            cv2.polylines(
                image, [box_points.astype(np.int32)], True, (0, 0, 255), 1
            )
            box_info_dict["points"] = box_points
            box_info_dict["text"] = word
            box_info_dict["ignore"] = word == "###"

            single_img_bboxes.append(box_info_dict)
        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)
    return total_imgs_bboxes, total_img_path
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def load_icdar2013_gt(dataFolder, isTraing=False):
    """Load ICDAR 2013 ground truth (axis-aligned boxes + transcription).

    Each gt file line is "xmin,ymin,xmax,ymax,word"; words equal to "###" are
    flagged as ignored. Both isTraing branches currently point at the test
    split; the branches are kept for signature parity with load_icdar2015_gt.

    BUG FIX: the gt file is now closed via ``with`` instead of being leaked,
    and a ``cv2.imread`` whose result was never used has been removed.

    :param dataFolder: ICDAR2013 root containing the image/gt folders
    :param isTraing: kept for interface compatibility (both branches identical)
    :return: (list of per-image box-dict lists, image paths)
    """
    if isTraing:
        img_folderName = "Challenge2_Test_Task12_Images"
        gt_folderName = "Challenge2_Test_Task1_GT"
    else:
        img_folderName = "Challenge2_Test_Task12_Images"
        gt_folderName = "Challenge2_Test_Task1_GT"

    gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName))

    total_imgs_bboxes = []
    total_img_path = []
    for gt_name in gt_folder_path:
        gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_name)
        # gt_xxx.txt in the gt folder maps to xxx.jpg in the image folder
        img_path = (
            gt_path.replace(gt_folderName, img_folderName)
            .replace(".txt", ".jpg")
            .replace("gt_", "")
        )
        with open(gt_path, encoding="utf-8") as f:
            lines = f.readlines()
        single_img_bboxes = []
        for line in lines:
            box_info_dict = {"points": None, "text": None, "ignore": None}

            # strip a possible UTF-8 BOM before splitting
            box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
            box = [int(box_info[j]) for j in range(4)]
            word = ",".join(box_info[4:])  # transcription may itself contain commas
            box_info_dict["points"] = [
                [box[0], box[1]],
                [box[2], box[1]],
                [box[2], box[3]],
                [box[0], box[3]],
            ]
            box_info_dict["text"] = word
            box_info_dict["ignore"] = word == "###"

            single_img_bboxes.append(box_info_dict)

        total_imgs_bboxes.append(single_img_bboxes)
        total_img_path.append(img_path)

    return total_imgs_bboxes, total_img_path
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def test_net(
    net,
    image,
    text_threshold,
    link_threshold,
    low_text,
    cuda,
    poly,
    canvas_size=1280,
    mag_ratio=1.5,
):
    """Run CRAFT inference on a single image and return detections.

    :param net: the CRAFT network; called as ``y, feature = net(x)``
    :param image: input image (numpy array, channels-last)
    :param text_threshold: forwarded to getDetBoxes
    :param link_threshold: forwarded to getDetBoxes
    :param low_text: forwarded to getDetBoxes
    :param cuda: when truthy, the input batch is moved to GPU
    :param poly: when truthy, refined polygons are also computed
    :param canvas_size: max side length for aspect-preserving resize
    :param mag_ratio: magnification applied during resize
    :return: (boxes, polys, render_img) where polys falls back to boxes
             entry-wise and render_img is [text heatmap, link heatmap]
    """
    # resize (aspect-preserving, up to canvas_size, magnified by mag_ratio)
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
        image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
    )
    # inverse ratio maps network-frame coordinates back to the original image
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map (channel 0 = region/text, channel 1 = affinity/link)
    score_text = y[0, :, :, 0].cpu().data.numpy().astype(np.float32)
    score_link = y[0, :, :, 1].cpu().data.numpy().astype(np.float32)

    # NOTE: crop the score maps to the valid heatmap area (padding removed)
    score_text = score_text[: size_heatmap[0], : size_heatmap[1]]
    score_link = score_link[: size_heatmap[0], : size_heatmap[1]]

    # Post-processing: score maps -> word boxes (and optional polygons)
    boxes, polys = getDetBoxes(
        score_text, score_link, text_threshold, link_threshold, low_text, poly
    )

    # coordinate adjustment back to the original image frame
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
    # fall back to the rectangular box wherever no polygon was fitted
    for k in range(len(polys)):
        if polys[k] is None:
            polys[k] = boxes[k]

    # render results (optional)
    score_text = score_text.copy()
    render_score_text = imgproc.cvt2HeatmapImg(score_text)
    render_score_link = imgproc.cvt2HeatmapImg(score_link)
    render_img = [render_score_text, render_score_link]
    # ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, render_img
|
trainer/craft/utils/util.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from data import imgproc
|
| 8 |
+
from utils import craft_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def copyStateDict(state_dict):
    """Return a copy of a state dict with any leading 'module' prefix stripped.

    Checkpoints saved from a DataParallel-wrapped model prefix every key with
    "module."; this normalizes such keys so the weights load into an
    unwrapped model. Non-prefixed dicts are copied unchanged.

    :param state_dict: mapping of parameter names to tensors
    :return: OrderedDict with normalized keys, same iteration order
    """
    start_idx = 1 if list(state_dict.keys())[0].startswith("module") else 0
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        trimmed_key = ".".join(key.split(".")[start_idx:])
        new_state_dict[trimmed_key] = value
    return new_state_dict
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def saveInput(
    imagename, vis_dir, image, region_scores, affinity_scores, confidence_mask
):
    """Write a training-input visualization: image with detected boxes,
    region/affinity heatmap overlays, and the confidence mask, side by side.

    :param imagename: image name (str) or a sequence whose first element is a
                      path, as produced by the SynthText loader
    :param vis_dir: output directory for the "<imagename>_input.jpg" file
    :param image: RGB image (converted to BGR for OpenCV drawing/writing)
    :param region_scores: 2D region (character) score map
    :param affinity_scores: 2D affinity (link) score map
    :param confidence_mask: 2D confidence map rendered as a heatmap panel
    """
    image = np.uint8(image.copy())
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # detect boxes from the supervision score maps with fixed thresholds
    boxes, polys = craft_utils.getDetBoxes(
        region_scores, affinity_scores, 0.85, 0.2, 0.5, False
    )

    # score maps may be half the image resolution; scale boxes up if so
    if image.shape[0] / region_scores.shape[0] >= 2:
        boxes = np.array(boxes, np.int32) * 2
    else:
        boxes = np.array(boxes, np.int32)

    if len(boxes) > 0:
        # NOTE(review): np.clip returns a new array here and the result is
        # discarded, so these two calls have no effect — confirm intent.
        np.clip(boxes[:, :, 0], 0, image.shape[1])
        np.clip(boxes[:, :, 1], 0, image.shape[0])
        for box in boxes:
            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)

    # overlay: resize heatmaps to image size and alpha-blend them onto it
    height, width, channel = image.shape
    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
    confidence_mask_gray = cv2.resize(
        confidence_mask_gray, (width, height), interpolation=cv2.INTER_NEAREST
    )
    overlay_region = cv2.addWeighted(image, 0.4, overlay_region, 0.6, 5)
    overlay_aff = cv2.addWeighted(image, 0.4, overlay_aff, 0.7, 6)

    # panel layout: [boxes image | region overlay | affinity overlay | mask]
    gt_scores = np.concatenate([overlay_region, overlay_aff], axis=1)

    output = np.concatenate([gt_scores, confidence_mask_gray], axis=1)

    output = np.hstack([image, output])

    # synthtext: loader yields a path sequence; reduce it to the base name
    if type(imagename) is not str:
        imagename = imagename[0].split("/")[-1][:-4]

    outpath = vis_dir + f"/{imagename}_input.jpg"
    if not os.path.exists(os.path.dirname(outpath)):
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
    cv2.imwrite(outpath, output)
    # print(f'Logging train input into {outpath}')
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def saveImage(
    imagename,
    vis_dir,
    image,
    bboxes,
    affi_bboxes,
    region_scores,
    affinity_scores,
    confidence_mask,
):
    """Write a side-by-side debug image to ``vis_dir``: drawn boxes |
    region/affinity heatmap overlays | confidence mask.

    Args:
        imagename: sample name (str), or a sequence whose first element is a
            path (the SynthText case handled below).
        vis_dir: output directory for the jpg.
        image: source image, H x W x 3, treated as RGB.
        bboxes: per-word arrays of character boxes, drawn in red.
        affi_bboxes: affinity boxes, drawn in blue.
        region_scores: region score map, rendered via imgproc.cvt2HeatmapImg.
        affinity_scores: affinity score map, rendered the same way.
        confidence_mask: confidence map, rendered the same way.

    NOTE(review): the final concatenation assumes confidence_mask_gray has the
    same height as the image -- no resize is applied to it here; confirm with
    callers that the mask is image-sized.
    """
    output_image = np.uint8(image.copy())
    # cv2 drawing and imwrite expect BGR channel order
    output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
    if len(bboxes) > 0:
        # red polylines: one polygon per character box of each word
        for i in range(len(bboxes)):
            _bboxes = np.int32(bboxes[i])
            for j in range(_bboxes.shape[0]):
                cv2.polylines(
                    output_image,
                    [np.reshape(_bboxes[j], (-1, 1, 2))],
                    True,
                    (0, 0, 255),
                )

        # blue polylines: affinity boxes between characters
        for i in range(len(affi_bboxes)):
            cv2.polylines(
                output_image,
                [np.reshape(affi_bboxes[i].astype(np.int32), (-1, 1, 2))],
                True,
                (255, 0, 0),
            )

    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)

    # overlay: blend the (resized) heatmaps onto copies of the source image
    height, width, channel = image.shape
    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))

    overlay_region = cv2.addWeighted(image.copy(), 0.4, overlay_region, 0.6, 5)
    overlay_aff = cv2.addWeighted(image.copy(), 0.4, overlay_aff, 0.6, 5)

    heat_map = np.concatenate([overlay_region, overlay_aff], axis=1)

    # synthtext: imagename arrives as a sequence of paths -> use the basename
    if type(imagename) is not str:
        imagename = imagename[0].split("/")[-1][:-4]

    output = np.concatenate([output_image, heat_map, confidence_mask_gray], axis=1)
    outpath = vis_dir + f"/{imagename}.jpg"
    if not os.path.exists(os.path.dirname(outpath)):
        os.makedirs(os.path.dirname(outpath), exist_ok=True)

    cv2.imwrite(outpath, output)
    # print(f'Logging original image into {outpath}')
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def save_parser(args):
    """Append the final option values to ``<results_dir>/opt.txt`` and echo them.

    Args:
        args: parsed argparse namespace; must expose ``results_dir``.
    """
    header = "------------ Options -------------\n"
    footer = "---------------------------------------\n"
    body = "".join(f"{str(k)}: {str(v)}\n" for k, v in vars(args).items())
    opt_log = header + body + footer
    with open(f"{args.results_dir}/opt.txt", "a", encoding="utf-8") as opt_file:
        print(opt_log)
        opt_file.write(opt_log)
|
trainer/dataset.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
import six
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from natsort import natsorted
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
from torch.utils.data import Dataset, ConcatDataset, Subset
|
| 13 |
+
from torch._utils import _accumulate
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
|
| 16 |
+
def contrast_grey(img):
    """Measure the contrast of a greyscale image.

    Contrast is (p90 - p10) / (p90 + p10) over pixel intensities.

    Args:
        img: greyscale image as a numpy array.

    Returns:
        Tuple ``(contrast, high, low)`` where ``high``/``low`` are the
        90th/10th percentiles. FIX: an all-black image (high + low == 0)
        now yields contrast 0.0 instead of the NaN produced by 0/0.
    """
    high = np.percentile(img, 90)
    low = np.percentile(img, 10)
    denom = high + low
    if denom == 0:  # all-black image: avoid 0/0 -> NaN
        return 0.0, high, low
    return (high - low) / denom, high, low
|
| 20 |
+
|
| 21 |
+
def adjust_contrast_grey(img, target=0.4):
    """Stretch a low-contrast greyscale image toward the target contrast.

    Pixels are linearly remapped so that the 10th..90th percentile range
    spans roughly 200 grey levels, then clipped to [0, 255].

    Args:
        img: greyscale image as a numpy array.
        target: contrast threshold below which the stretch is applied.

    Returns:
        The adjusted uint8 image, or the input unchanged when its contrast
        already meets the target (or when stretching is impossible).
    """
    contrast, high, low = contrast_grey(img)
    # FIX: guard high > low -- a flat image (high == low) has contrast 0 and
    # would otherwise divide by zero when computing the stretch ratio.
    if contrast < target and high > low:
        img = img.astype(int)
        ratio = 200. / (high - low)
        img = (img - low + 25) * ratio
        img = np.maximum(np.full(img.shape, 0), np.minimum(np.full(img.shape, 255), img)).astype(np.uint8)
    return img
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Batch_Balanced_Dataset(object):
    """Draws each training batch as a fixed mixture of several datasets.

    Modulate the data ratio in the batch. For example, when select_data is
    "MJ-ST" and batch_ratio is "0.5-0.5", the 50% of the batch is filled
    with MJ and the other 50% of the batch is filled with ST.
    """

    def __init__(self, opt):
        """Build one DataLoader per selected dataset, sized by its batch ratio.

        Side effects: appends a summary to
        ``./saved_models/<experiment_name>/log_dataset.txt`` and overwrites
        ``opt.batch_size`` with the summed per-dataset batch sizes.
        """
        log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
        dashed_line = '-' * 80
        print(dashed_line)
        log.write(dashed_line + '\n')
        print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
        log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
        assert len(opt.select_data) == len(opt.batch_ratio)

        _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
        self.data_loader_list = []
        self.dataloader_iter_list = []
        batch_size_list = []
        Total_batch_size = 0
        for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
            # each dataset contributes at least one sample per batch
            _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
            print(dashed_line)
            log.write(dashed_line + '\n')
            _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
            total_number_dataset = len(_dataset)
            log.write(_dataset_log)

            """
            The total number of data can be modified with opt.total_data_usage_ratio.
            ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
            See 4.2 section in our paper.
            """
            number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
            dataset_split = [number_dataset, total_number_dataset - number_dataset]
            indices = range(total_number_dataset)
            # keep only the first split (the used fraction); discard the rest
            _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
                           for offset, length in zip(_accumulate(dataset_split), dataset_split)]
            selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
            selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
            print(selected_d_log)
            log.write(selected_d_log + '\n')
            batch_size_list.append(str(_batch_size))
            Total_batch_size += _batch_size

            _data_loader = torch.utils.data.DataLoader(
                _dataset, batch_size=_batch_size,
                shuffle=True,
                num_workers=int(opt.workers),  # prefetch_factor=2,persistent_workers=True,
                collate_fn=_AlignCollate, pin_memory=True)
            self.data_loader_list.append(_data_loader)
            self.dataloader_iter_list.append(iter(_data_loader))

        Total_batch_size_log = f'{dashed_line}\n'
        batch_size_sum = '+'.join(batch_size_list)
        Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
        Total_batch_size_log += f'{dashed_line}'
        opt.batch_size = Total_batch_size

        print(Total_batch_size_log)
        log.write(Total_batch_size_log + '\n')
        log.close()

    def get_batch(self):
        """Return one balanced batch: images concatenated along dim 0,
        labels as a flat list, drawing from every loader and restarting
        any loader that is exhausted."""
        balanced_batch_images = []
        balanced_batch_texts = []

        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                # FIX: use the builtin next(); Python 3 iterators (including
                # modern torch DataLoader iterators) have no .next() method.
                image, text = next(data_loader_iter)
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except StopIteration:
                # loader exhausted: restart it and take its first batch
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                image, text = next(self.dataloader_iter_list[i])
                balanced_batch_images.append(image)
                balanced_batch_texts += text
            except ValueError:
                pass

        balanced_batch_images = torch.cat(balanced_batch_images, 0)

        return balanced_batch_images, balanced_batch_texts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def hierarchical_dataset(root, opt, select_data='/'):
    """ select_data='/' contains all sub-directory of root directory """
    dataset_list = []
    dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
    print(dataset_log)
    dataset_log += '\n'
    for dirpath, dirnames, filenames in os.walk(root + '/'):
        if dirnames:
            continue  # only leaf directories hold samples
        # keep this leaf only if it matches one of the selected dataset names
        if not any(selected_d in dirpath for selected_d in select_data):
            continue
        dataset = OCRDataset(dirpath, opt)
        sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
        print(sub_dataset_log)
        dataset_log += f'{sub_dataset_log}\n'
        dataset_list.append(dataset)

    return ConcatDataset(dataset_list), dataset_log
|
| 141 |
+
|
| 142 |
+
class OCRDataset(Dataset):
    """CSV-labelled OCR dataset: image files in ``root`` listed in ``root/labels.csv``."""

    def __init__(self, root, opt):
        """
        Args:
            root: directory containing labels.csv and the image files.
            opt: options namespace; reads data_filtering_off,
                 batch_max_length, character, rgb, sensitive.
        """
        self.root = root
        self.opt = opt
        print(root)
        # labels.csv columns: filename,words -- the regex separator splits on
        # the FIRST comma only, so commas inside the label stay in 'words'.
        self.df = pd.read_csv(os.path.join(root, 'labels.csv'), sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)
        self.nSamples = len(self.df)

        if self.opt.data_filtering_off:
            # FIX: DataFrame rows are 0-based. The previous ``index + 1``
            # skipped row 0 and produced an out-of-range index for the last
            # sample when filtering was disabled.
            self.filtered_index_list = list(range(self.nSamples))
        else:
            self.filtered_index_list = []
            for index in range(self.nSamples):
                label = self.df.at[index, 'words']
                try:
                    if len(label) > self.opt.batch_max_length:
                        continue
                except TypeError:
                    # FIX: a malformed (non-string) label used to fall
                    # through and crash on label.lower(); skip it instead.
                    print(label)
                    continue
                out_of_char = f'[^{self.opt.character}]'
                if re.search(out_of_char, label.lower()):
                    continue  # label contains characters outside the charset
                self.filtered_index_list.append(index)
            self.nSamples = len(self.filtered_index_list)

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(PIL image, filtered label)`` for the index-th kept sample."""
        index = self.filtered_index_list[index]
        img_fname = self.df.at[index, 'filename']
        img_fpath = os.path.join(self.root, img_fname)
        label = self.df.at[index, 'words']

        if self.opt.rgb:
            img = Image.open(img_fpath).convert('RGB')  # for color image
        else:
            img = Image.open(img_fpath).convert('L')

        if not self.opt.sensitive:
            label = label.lower()

        # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
        out_of_char = f'[^{self.opt.character}]'
        label = re.sub(out_of_char, '', label)

        return (img, label)
|
| 191 |
+
|
| 192 |
+
class ResizeNormalize(object):
    """Resize a PIL image to a fixed size and normalize it to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        return tensor.sub_(0.5).div_(0.5)  # map [0, 1] -> [-1, 1] in place
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class NormalizePAD(object):
    """Normalize an image tensor to [-1, 1] and pad its width to a fixed size."""

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        self.max_size = max_size
        self.max_width_half = math.floor(max_size[2] / 2)
        self.PAD_type = PAD_type

    def __call__(self, img):
        tensor = self.toTensor(img)
        tensor.sub_(0.5).div_(0.5)  # map [0, 1] -> [-1, 1] in place
        c, h, w = tensor.size()
        padded = torch.FloatTensor(*self.max_size).fill_(0)
        padded[:, :, :w] = tensor  # content sits at the left edge
        if w != self.max_size[2]:
            # replicate the rightmost column across the remaining width
            padded[:, :, w:] = tensor[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)

        return padded
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class AlignCollate(object):
    """Collate (image, label) pairs into a normalized image batch plus labels."""

    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, contrast_adjust=0.):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad
        self.contrast_adjust = contrast_adjust

    def __call__(self, batch):
        batch = [sample for sample in batch if sample is not None]
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, self.imgW))

            resized_images = []
            for image in images:
                w, h = image.size

                #### augmentation here - change contrast
                if self.contrast_adjust > 0:
                    grey = np.array(image.convert("L"))
                    grey = adjust_contrast_grey(grey, target=self.contrast_adjust)
                    image = Image.fromarray(grey, 'L')

                aspect_ratio = w / float(h)
                # scale to target height, capping width at imgW
                resized_w = min(self.imgW, math.ceil(self.imgH * aspect_ratio))
                resized = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = torch.cat([transform(image).unsqueeze(0) for image in images], 0)

        return image_tensors, labels
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def tensor2im(image_tensor, imtype=np.uint8):
    """Convert a CHW tensor in [-1, 1] to an HWC numpy image in [0, 255]."""
    array = image_tensor.cpu().float().numpy()
    if array.shape[0] == 1:
        array = np.tile(array, (3, 1, 1))  # grey -> 3-channel
    array = (np.transpose(array, (1, 2, 0)) + 1) / 2.0 * 255.0
    return array.astype(imtype)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` via PIL."""
    Image.fromarray(image_numpy).save(image_path)
|
trainer/model.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from modules.transformation import TPS_SpatialTransformerNetwork
|
| 3 |
+
from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
|
| 4 |
+
from modules.sequence_modeling import BidirectionalLSTM
|
| 5 |
+
from modules.prediction import Attention
|
| 6 |
+
|
| 7 |
+
class Model(nn.Module):
    """Four-stage text recognition network (TPS -> CNN -> BiLSTM -> CTC/Attn).

    Each stage is selected by string options on ``opt``:
    Transformation ('TPS' or 'None'), FeatureExtraction ('VGG'/'RCNN'/'ResNet'),
    SequenceModeling ('BiLSTM' or 'None'), Prediction ('CTC' or 'Attn').
    """

    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        # remember the chosen stage names so forward() can branch on them
        self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
                       'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

        """ Transformation """
        if opt.Transformation == 'TPS':
            self.Transformation = TPS_SpatialTransformerNetwork(
                F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
        else:
            # no module is created in this case; forward() must see 'None'
            print('No Transformation module specified')

        """ FeatureExtraction """
        if opt.FeatureExtraction == 'VGG':
            self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'RCNN':
            self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
        elif opt.FeatureExtraction == 'ResNet':
            self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
        else:
            raise Exception('No FeatureExtraction module specified')
        self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        if opt.SequenceModeling == 'BiLSTM':
            self.SequenceModeling = nn.Sequential(
                BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
                BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
            self.SequenceModeling_output = opt.hidden_size
        else:
            print('No SequenceModeling module specified')
            self.SequenceModeling_output = self.FeatureExtraction_output

        """ Prediction """
        if opt.Prediction == 'CTC':
            self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
        elif opt.Prediction == 'Attn':
            self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
        else:
            raise Exception('Prediction is neither CTC or Attn')

    def forward(self, input, text, is_train=True):
        """Run all configured stages.

        Args:
            input: image batch tensor.
            text: ground-truth/teacher-forcing text tensor (used by Attn only).
            is_train: passed to the attention predictor to toggle teacher forcing.

        Returns:
            Per-timestep class scores from the prediction head.
        """
        """ Transformation stage """
        if not self.stages['Trans'] == "None":
            input = self.Transformation(input)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)  # drop the pooled height axis -> [b, w, c]

        """ Sequence modeling stage """
        if self.stages['Seq'] == 'BiLSTM':
            contextual_feature = self.SequenceModeling(visual_feature)
        else:
            contextual_feature = visual_feature  # for convenience. this is NOT contextually modeled by BiLSTM

        """ Prediction stage """
        if self.stages['Pred'] == 'CTC':
            prediction = self.Prediction(contextual_feature.contiguous())
        else:
            prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)

        return prediction
|
trainer/modules/feature_extraction.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class VGG_FeatureExtractor(nn.Module):
    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(VGG_FeatureExtractor, self).__init__()
        # channel widths per stage, e.g. [64, 128, 256, 512] for the default 512
        widths = [int(output_channel / 8), int(output_channel / 4),
                  int(output_channel / 2), output_channel]
        self.output_channel = widths
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, widths[0], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 64x16x50
            nn.Conv2d(widths[0], widths[1], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 128x8x25
            nn.Conv2d(widths[1], widths[2], 3, 1, 1), nn.ReLU(True),  # 256x8x25
            nn.Conv2d(widths[2], widths[2], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),  # 256x4x25 (halve height only)
            nn.Conv2d(widths[2], widths[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(widths[3]), nn.ReLU(True),  # 512x4x25
            nn.Conv2d(widths[3], widths[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(widths[3]), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),  # 512x2x25
            nn.Conv2d(widths[3], widths[3], 2, 1, 0), nn.ReLU(True))  # 512x1x24

    def forward(self, input):
        return self.ConvNet(input)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class RCNN_FeatureExtractor(nn.Module):
    """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(RCNN_FeatureExtractor, self).__init__()
        # channel widths per stage, e.g. [64, 128, 256, 512] for the default 512
        widths = [int(output_channel / 8), int(output_channel / 4),
                  int(output_channel / 2), output_channel]
        self.output_channel = widths
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, widths[0], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # 64 x 16 x 50
            GRCL(widths[0], widths[0], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, 2),  # 64 x 8 x 25
            GRCL(widths[0], widths[1], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, (2, 1), (0, 1)),  # 128 x 4 x 26
            GRCL(widths[1], widths[2], num_iteration=5, kernel_size=3, pad=1),
            nn.MaxPool2d(2, (2, 1), (0, 1)),  # 256 x 2 x 27
            nn.Conv2d(widths[2], widths[3], 2, 1, 0, bias=False),
            nn.BatchNorm2d(widths[3]), nn.ReLU(True))  # 512 x 1 x 26

    def forward(self, input):
        return self.ConvNet(input)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ResNet_FeatureExtractor(nn.Module):
    """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

    def __init__(self, input_channel, output_channel=512):
        super(ResNet_FeatureExtractor, self).__init__()
        # [1, 2, 5, 3] BasicBlocks per stage -- the ResNet variant used by FAN
        self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])

    def forward(self, input):
        return self.ConvNet(input)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# For Gated RCNN
|
| 66 |
+
class GRCL(nn.Module):
    """Gated Recurrent Convolution Layer: applies a gated conv unit
    num_iteration times to the same feed-forward input."""

    def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad):
        super(GRCL, self).__init__()
        # 1x1 convs feed the gate; kxk convs feed the feature path
        self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False)
        self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False)
        self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False)
        self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False)

        self.BN_x_init = nn.BatchNorm2d(output_channel)

        self.num_iteration = num_iteration
        self.GRCL = nn.Sequential(*[GRCL_unit(output_channel) for _ in range(num_iteration)])

    def forward(self, input):
        """ The input of GRCL is consistant over time t, which is denoted by u(0)
        thus wgf_u / wf_u is also consistant over time t.
        """
        gate_feed = self.wgf_u(input)
        feat_feed = self.wf_u(input)
        x = F.relu(self.BN_x_init(feat_feed))

        for unit in self.GRCL:
            x = unit(gate_feed, self.wgr_x(x), feat_feed, self.wr_x(x))

        return x
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class GRCL_unit(nn.Module):
    """One recurrence step of the GRCL: gated update of the feature map.

    Computes G = sigmoid(BN(wgf_u) + BN(wgr_x)) and
    x = relu(BN(wf_u) + BN(BN(wr_x) * G)).
    """

    def __init__(self, output_channel):
        super(GRCL_unit, self).__init__()
        # separate BatchNorms for each of the five normalized terms
        self.BN_gfu = nn.BatchNorm2d(output_channel)
        self.BN_grx = nn.BatchNorm2d(output_channel)
        self.BN_fu = nn.BatchNorm2d(output_channel)
        self.BN_rx = nn.BatchNorm2d(output_channel)
        self.BN_Gx = nn.BatchNorm2d(output_channel)

    def forward(self, wgf_u, wgr_x, wf_u, wr_x):
        """Apply the gated recurrent update; all inputs are pre-convolved maps."""
        G_first_term = self.BN_gfu(wgf_u)
        G_second_term = self.BN_grx(wgr_x)
        # FIX: F.sigmoid is deprecated (removed from torch.nn.functional in
        # recent PyTorch); use the Tensor.sigmoid method instead.
        G = (G_first_term + G_second_term).sigmoid()

        x_first_term = self.BN_fu(wf_u)
        x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G)  # gated recurrent path
        x = F.relu(x_first_term + x_second_term)

        return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class BasicBlock(nn.Module):
    """Two-conv residual block with optional downsample on the shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    def forward(self, x):
        # shortcut is the raw input unless the block changes shape/channels
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class ResNet(nn.Module):
|
| 154 |
+
|
| 155 |
+
def __init__(self, input_channel, output_channel, block, layers):
|
| 156 |
+
super(ResNet, self).__init__()
|
| 157 |
+
|
| 158 |
+
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
|
| 159 |
+
|
| 160 |
+
self.inplanes = int(output_channel / 8)
|
| 161 |
+
self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
|
| 162 |
+
kernel_size=3, stride=1, padding=1, bias=False)
|
| 163 |
+
self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
|
| 164 |
+
self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
|
| 165 |
+
kernel_size=3, stride=1, padding=1, bias=False)
|
| 166 |
+
self.bn0_2 = nn.BatchNorm2d(self.inplanes)
|
| 167 |
+
self.relu = nn.ReLU(inplace=True)
|
| 168 |
+
|
| 169 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
| 170 |
+
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
|
| 171 |
+
self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
|
| 172 |
+
0], kernel_size=3, stride=1, padding=1, bias=False)
|
| 173 |
+
self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
|
| 174 |
+
|
| 175 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
| 176 |
+
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
|
| 177 |
+
self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
|
| 178 |
+
1], kernel_size=3, stride=1, padding=1, bias=False)
|
| 179 |
+
self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
|
| 180 |
+
|
| 181 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
|
| 182 |
+
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
|
| 183 |
+
self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
|
| 184 |
+
2], kernel_size=3, stride=1, padding=1, bias=False)
|
| 185 |
+
self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
|
| 186 |
+
|
| 187 |
+
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
|
| 188 |
+
self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
|
| 189 |
+
3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
|
| 190 |
+
self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
|
| 191 |
+
self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
|
| 192 |
+
3], kernel_size=2, stride=1, padding=0, bias=False)
|
| 193 |
+
self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
|
| 194 |
+
|
| 195 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 196 |
+
downsample = None
|
| 197 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 198 |
+
downsample = nn.Sequential(
|
| 199 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 200 |
+
kernel_size=1, stride=stride, bias=False),
|
| 201 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
layers = []
|
| 205 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 206 |
+
self.inplanes = planes * block.expansion
|
| 207 |
+
for i in range(1, blocks):
|
| 208 |
+
layers.append(block(self.inplanes, planes))
|
| 209 |
+
|
| 210 |
+
return nn.Sequential(*layers)
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
    """Run the ResNet-style feature extractor.

    Four conv stages interleaved with pooling; the pooling configured in
    __init__ (stride (2,1) from maxpool3 / conv4_1 onward) downsamples
    height more aggressively than width, which suits text-line images.
    Returns the final feature map after conv4_2/bn4_2/ReLU.
    """
    def _cbr(conv, bn, feat):
        # conv -> batch-norm -> shared in-place ReLU
        return self.relu(bn(conv(feat)))

    out = _cbr(self.conv0_1, self.bn0_1, x)
    out = _cbr(self.conv0_2, self.bn0_2, out)

    out = self.layer1(self.maxpool1(out))
    out = _cbr(self.conv1, self.bn1, out)

    out = self.layer2(self.maxpool2(out))
    out = _cbr(self.conv2, self.bn2, out)

    out = self.layer3(self.maxpool3(out))
    out = _cbr(self.conv3, self.bn3, out)

    out = self.layer4(out)
    out = _cbr(self.conv4_1, self.bn4_1, out)
    out = _cbr(self.conv4_2, self.bn4_2, out)
    return out
|
trainer/modules/prediction.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Attention(nn.Module):
    """Attention-based sequence decoder.

    Decodes a character sequence from the encoder's contextual features,
    one step at a time, via an AttentionCell (additive attention + LSTM
    cell) followed by a linear classifier over the character set.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.generator = nn.Linear(hidden_size, num_classes)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """Expand a batch of character indices into one-hot rows of width `onehot_dim`."""
        indices = input_char.unsqueeze(1)
        n = indices.size(0)
        one_hot = torch.FloatTensor(n, onehot_dim).zero_().to(device)
        return one_hot.scatter_(1, indices, 1)

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
        output: probability distribution at each step [batch_size x num_steps x num_classes]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # extra step for the [s] end-of-sequence token

        output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
        hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
                  torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))

        if is_train:
            # Teacher forcing: feed the ground-truth previous character at each step.
            for step in range(num_steps):
                onehots = self._char_to_onehot(text[:, step], onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, batch_H, onehots)
                output_hiddens[:, step, :] = hidden[0]  # LSTM state tuple: (hidden, cell)
            probs = self.generator(output_hiddens)
        else:
            # Greedy decoding: feed back the argmax of the previous step.
            targets = torch.LongTensor(batch_size).fill_(0).to(device)  # start from the [GO] token (index 0)
            probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)

            for step in range(num_steps):
                onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
                hidden, _ = self.attention_cell(hidden, batch_H, onehots)
                step_probs = self.generator(hidden[0])
                probs[:, step, :] = step_probs
                _, targets = step_probs.max(1)

        return probs  # batch_size x num_steps x num_classes
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class AttentionCell(nn.Module):
    """One decoding step: additive (Bahdanau-style) attention over the
    encoder sequence, then an LSTM cell over [context ; one-hot(prev char)]."""

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2h or h2h should have bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # Project encoder states: [B x T x input_size] -> [B x T x hidden_size]
        encoder_proj = self.i2h(batch_H)
        # Project previous decoder hidden state and broadcast over time.
        decoder_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        # Additive attention energies: [B x T x 1]
        energies = self.score(torch.tanh(encoder_proj + decoder_proj))

        alpha = F.softmax(energies, dim=1)
        # Attention-weighted sum of encoder states: [B x input_size]
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)
        # [B x (input_size + num_embeddings)]
        rnn_input = torch.cat([context, char_onehots], 1)
        cur_hidden = self.rnn(rnn_input, prev_hidden)
        return cur_hidden, alpha
|
trainer/modules/sequence_modeling.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection.

    Maps a visual feature sequence [batch_size x T x input_size] to a
    contextual feature sequence [batch_size x T x output_size].
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Forward and backward directions are concatenated -> 2 * hidden_size.
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size]
        output : contextual feature [batch_size x T x output_size]
        """
        try:
            # Compact the weights into one contiguous chunk for faster cuDNN
            # execution. Best-effort only: it can fail (e.g. on replicated
            # modules under DataParallel), in which case the LSTM still runs
            # correctly, just without the speed-up.
            self.rnn.flatten_parameters()
        except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            pass
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output
|
trainer/modules/transformation.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TPS_SpatialTransformerNetwork(nn.Module):
    """ Rectification Network of RARE, namely TPS based STN """

    def __init__(self, F, I_size, I_r_size, I_channel_num=1):
        """ Based on RARE TPS
        input:
            batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
            I_size : (height, width) of the input image I
            I_r_size : (height, width) of the rectified image I_r
            I_channel_num : the number of channels of the input image I
        output:
            batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
        """
        super(TPS_SpatialTransformerNetwork, self).__init__()
        self.F = F
        self.I_size = I_size
        self.I_r_size = I_r_size  # = (I_r_height, I_r_width)
        self.I_channel_num = I_channel_num
        self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
        self.GridGenerator = GridGenerator(self.F, self.I_r_size)

    def forward(self, batch_I):
        # Predict fiducial points, build the TPS sampling grid from them,
        # then resample the input through that grid.
        fiducials = self.LocalizationNetwork(batch_I)  # batch_size x F x 2
        grid_flat = self.GridGenerator.build_P_prime(fiducials)  # batch_size x n (= I_r_width x I_r_height) x 2
        grid = grid_flat.reshape([grid_flat.size(0), self.I_r_size[0], self.I_r_size[1], 2])
        return F.grid_sample(batch_I, grid, padding_mode='border')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LocalizationNetwork(nn.Module):
    """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """

    def __init__(self, F, I_channel_num):
        super(LocalizationNetwork, self).__init__()
        self.F = F
        self.I_channel_num = I_channel_num
        # Four conv/BN/ReLU stages; the first three each halve the spatial
        # size, and the adaptive pool collapses the map to a 512-d vector.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
                      bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 64 x I_height/2 x I_width/2
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 128 x I_height/4 x I_width/4
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.MaxPool2d(2, 2),  # batch_size x 256 x I_height/8 x I_width/8
            nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1)  # batch_size x 512 x 1 x 1
        )

        self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
        self.localization_fc2 = nn.Linear(256, self.F * 2)

        # Zero the final weights and bias-initialise to the default fiducial
        # layout (two horizontal rows; see RARE paper Fig. 6 (a)), so the STN
        # initially behaves like an identity transform.
        self.localization_fc2.weight.data.fill_(0)
        half = int(F / 2)
        xs = np.linspace(-1.0, 1.0, half)
        top_row = np.stack([xs, np.linspace(0.0, -1.0, num=half)], axis=1)
        bottom_row = np.stack([xs, np.linspace(1.0, 0.0, num=half)], axis=1)
        initial_bias = np.concatenate([top_row, bottom_row], axis=0)
        self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)

    def forward(self, batch_I):
        """
        input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
        output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
        """
        n = batch_I.size(0)
        descriptor = self.conv(batch_I).view(n, -1)
        return self.localization_fc2(self.localization_fc1(descriptor)).view(n, self.F, 2)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class GridGenerator(nn.Module):
    """ Grid Generator of RARE, which produces P_prime by multiplying T with P """

    def __init__(self, F, I_r_size):
        """ Precompute the TPS geometry: fiducial points C, rectified grid P,
        the inverse system matrix inv_delta_C, and the basis matrix P_hat. """
        super(GridGenerator, self).__init__()
        self.eps = 1e-6
        self.I_r_height, self.I_r_width = I_r_size
        self.F = F
        self.C = self._build_C(self.F)  # F x 2
        self.P = self._build_P(self.I_r_width, self.I_r_height)
        # Registered as buffers so multi-GPU (DataParallel) replication works.
        self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
        self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3
        # For fine-tuning with a different image width, plain `.cuda()` tensors
        # can be assigned instead of using register_buffer.

    def _build_C(self, F):
        """ Return coordinates of fiducial points in I_r; two horizontal rows at y = -1 and y = +1. """
        half = int(F / 2)
        xs = np.linspace(-1.0, 1.0, half)
        top = np.stack([xs, -np.ones(half)], axis=1)
        bottom = np.stack([xs, np.ones(half)], axis=1)
        return np.concatenate([top, bottom], axis=0)  # F x 2

    def _build_inv_delta_C(self, F, C):
        """ Return the inverse of the TPS system matrix delta_C (F+3 x F+3). """
        hat_C = np.zeros((F, F), dtype=float)  # pairwise distances, F x F
        for i in range(0, F):
            for j in range(i, F):
                r = np.linalg.norm(C[i] - C[j])
                hat_C[i, j] = r
                hat_C[j, i] = r
        # r = 0 on the diagonal; substituting 1 makes r^2 * log(r) evaluate to 0.
        np.fill_diagonal(hat_C, 1)
        hat_C = (hat_C ** 2) * np.log(hat_C)
        delta_C = np.concatenate(  # F+3 x F+3
            [
                np.concatenate([np.ones((F, 1)), C, hat_C], axis=1),           # F x F+3
                np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1),   # 2 x F+3
                np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1),   # 1 x F+3
            ],
            axis=0,
        )
        return np.linalg.inv(delta_C)  # F+3 x F+3

    def _build_P(self, I_r_width, I_r_height):
        """ Return the pixel-centre grid of the rectified image, normalised to [-1, 1]. """
        grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width
        grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height
        P = np.stack(np.meshgrid(grid_x, grid_y), axis=2)
        return P.reshape([-1, 2])  # n (= I_r_width x I_r_height) x 2

    def _build_P_hat(self, F, C, P):
        """ Return the TPS basis [1, x, y, U(|P - C_k|)] evaluated at every grid point. """
        n = P.shape[0]  # n (= I_r_width x I_r_height)
        diff = np.expand_dims(P, axis=1) - np.expand_dims(C, axis=0)  # n x F x 2 (broadcast)
        dist = np.linalg.norm(diff, ord=2, axis=2, keepdims=False)  # n x F
        rbf = np.square(dist) * np.log(dist + self.eps)  # n x F; eps avoids log(0)
        return np.concatenate([np.ones((n, 1)), P, rbf], axis=1)  # n x F+3

    def build_P_prime(self, batch_C_prime):
        """ Generate Grid from batch_C_prime [batch_size x F x 2] """
        batch_size = batch_C_prime.size(0)
        inv_delta = self.inv_delta_C.repeat(batch_size, 1, 1)
        basis = self.P_hat.repeat(batch_size, 1, 1)
        pad = torch.zeros(batch_size, 3, 2).float().to(device)
        rhs = torch.cat((batch_C_prime, pad), dim=1)  # batch_size x F+3 x 2
        T = torch.bmm(inv_delta, rhs)  # batch_size x F+3 x 2
        return torch.bmm(basis, T)  # batch_size x n x 2
|
trainer/saved_models/folder.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
trained model will be saved here
|
trainer/test.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import string
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.backends.cudnn as cudnn
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import numpy as np
|
| 11 |
+
from nltk.metrics.distance import edit_distance
|
| 12 |
+
|
| 13 |
+
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
| 14 |
+
from dataset import hierarchical_dataset, AlignCollate
|
| 15 |
+
from model import Model
|
| 16 |
+
|
| 17 |
+
def validation(model, criterion, evaluation_loader, converter, opt, device):
    """ validation or evaluation

    Runs `model` over `evaluation_loader` and accumulates:
      - average loss (CTC or cross-entropy, depending on opt.Prediction)
      - word-level accuracy (exact match, after [s]-pruning for Attn models)
      - ICDAR2019 normalized edit distance
      - per-sample confidence scores (product of per-step max probabilities)

    NOTE: preds_str, confidence_score_list and labels in the return value
    reflect only the LAST batch (confidence_score_list is re-created every
    iteration); callers use them for sample display, not for metrics.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' to use CTCloss format (T x B x C)
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            if opt.decode == 'greedy':
                # Select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            elif opt.decode == 'beamsearch':
                preds_str = converter.decode_beamsearch(preds, beamWidth=2)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # select max probabilty (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []

        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                # NOTE(review): if '[s]' is absent, find() returns -1 and the
                # slice drops the last character — presumably the decoder
                # always emits [s] within batch_max_length steps; verify.
                gt = gt[:gt.find('[s]')]
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1

            '''
            (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
            "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)
            '''

            # ICDAR2019 Normalized Edit Distance
            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

            # calculate confidence score (= multiply of pred_max_prob)
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                # empty pred (everything pruned after the [s] token) gives an
                # empty tensor, so indexing [-1] raises IndexError; was a bare
                # `except:` which also masked unrelated errors.
                confidence_score = 0
            confidence_score_list.append(confidence_score)
            # print(pred, gt, pred==gt, confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
|
trainer/train.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import torch.backends.cudnn as cudnn
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.init as init
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
import torch.utils.data
|
| 11 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from utils import CTCLabelConverter, AttnLabelConverter, Averager
|
| 15 |
+
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
|
| 16 |
+
from model import Model
|
| 17 |
+
from test import validation
|
| 18 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 19 |
+
|
| 20 |
+
def count_parameters(model):
    """Print every trainable parameter tensor (name and element count)
    and return the total number of trainable parameters."""
    print("Modules, Parameters")
    total = 0
    for name, tensor in model.named_parameters():
        # Frozen tensors (requires_grad == False) are excluded from the count.
        if not tensor.requires_grad:
            continue
        n = tensor.numel()
        total += n
        print(name, n)
    print(f"Total Trainable Params: {total}")
    return total
|
| 31 |
+
|
| 32 |
+
def train(opt, show_number = 2, amp=False):
|
| 33 |
+
""" dataset preparation """
|
| 34 |
+
if not opt.data_filtering_off:
|
| 35 |
+
print('Filtering the images containing characters which are not in opt.character')
|
| 36 |
+
print('Filtering the images whose label is longer than opt.batch_max_length')
|
| 37 |
+
|
| 38 |
+
opt.select_data = opt.select_data.split('-')
|
| 39 |
+
opt.batch_ratio = opt.batch_ratio.split('-')
|
| 40 |
+
train_dataset = Batch_Balanced_Dataset(opt)
|
| 41 |
+
|
| 42 |
+
log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8")
|
| 43 |
+
AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust)
|
| 44 |
+
valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
|
| 45 |
+
valid_loader = torch.utils.data.DataLoader(
|
| 46 |
+
valid_dataset, batch_size=min(32, opt.batch_size),
|
| 47 |
+
shuffle=True, # 'True' to check training progress with validation function.
|
| 48 |
+
num_workers=int(opt.workers), prefetch_factor=512,
|
| 49 |
+
collate_fn=AlignCollate_valid, pin_memory=True)
|
| 50 |
+
log.write(valid_dataset_log)
|
| 51 |
+
print('-' * 80)
|
| 52 |
+
log.write('-' * 80 + '\n')
|
| 53 |
+
log.close()
|
| 54 |
+
|
| 55 |
+
""" model configuration """
|
| 56 |
+
if 'CTC' in opt.Prediction:
|
| 57 |
+
converter = CTCLabelConverter(opt.character)
|
| 58 |
+
else:
|
| 59 |
+
converter = AttnLabelConverter(opt.character)
|
| 60 |
+
opt.num_class = len(converter.character)
|
| 61 |
+
|
| 62 |
+
if opt.rgb:
|
| 63 |
+
opt.input_channel = 3
|
| 64 |
+
model = Model(opt)
|
| 65 |
+
print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
|
| 66 |
+
opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
|
| 67 |
+
opt.SequenceModeling, opt.Prediction)
|
| 68 |
+
|
| 69 |
+
if opt.saved_model != '':
|
| 70 |
+
pretrained_dict = torch.load(opt.saved_model)
|
| 71 |
+
if opt.new_prediction:
|
| 72 |
+
model.Prediction = nn.Linear(model.SequenceModeling_output, len(pretrained_dict['module.Prediction.weight']))
|
| 73 |
+
|
| 74 |
+
model = torch.nn.DataParallel(model).to(device)
|
| 75 |
+
print(f'loading pretrained model from {opt.saved_model}')
|
| 76 |
+
if opt.FT:
|
| 77 |
+
model.load_state_dict(pretrained_dict, strict=False)
|
| 78 |
+
else:
|
| 79 |
+
model.load_state_dict(pretrained_dict)
|
| 80 |
+
if opt.new_prediction:
|
| 81 |
+
model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class)
|
| 82 |
+
for name, param in model.module.Prediction.named_parameters():
|
| 83 |
+
if 'bias' in name:
|
| 84 |
+
init.constant_(param, 0.0)
|
| 85 |
+
elif 'weight' in name:
|
| 86 |
+
init.kaiming_normal_(param)
|
| 87 |
+
model = model.to(device)
|
| 88 |
+
else:
|
| 89 |
+
# weight initialization
|
| 90 |
+
for name, param in model.named_parameters():
|
| 91 |
+
if 'localization_fc2' in name:
|
| 92 |
+
print(f'Skip {name} as it is already initialized')
|
| 93 |
+
continue
|
| 94 |
+
try:
|
| 95 |
+
if 'bias' in name:
|
| 96 |
+
init.constant_(param, 0.0)
|
| 97 |
+
elif 'weight' in name:
|
| 98 |
+
init.kaiming_normal_(param)
|
| 99 |
+
except Exception as e: # for batchnorm.
|
| 100 |
+
if 'weight' in name:
|
| 101 |
+
param.data.fill_(1)
|
| 102 |
+
continue
|
| 103 |
+
model = torch.nn.DataParallel(model).to(device)
|
| 104 |
+
|
| 105 |
+
model.train()
|
| 106 |
+
print("Model:")
|
| 107 |
+
print(model)
|
| 108 |
+
count_parameters(model)
|
| 109 |
+
|
| 110 |
+
""" setup loss """
|
| 111 |
+
if 'CTC' in opt.Prediction:
|
| 112 |
+
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
|
| 113 |
+
else:
|
| 114 |
+
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
|
| 115 |
+
# loss averager
|
| 116 |
+
loss_avg = Averager()
|
| 117 |
+
|
| 118 |
+
# freeze some layers
|
| 119 |
+
try:
|
| 120 |
+
if opt.freeze_FeatureFxtraction:
|
| 121 |
+
for param in model.module.FeatureExtraction.parameters():
|
| 122 |
+
param.requires_grad = False
|
| 123 |
+
if opt.freeze_SequenceModeling:
|
| 124 |
+
for param in model.module.SequenceModeling.parameters():
|
| 125 |
+
param.requires_grad = False
|
| 126 |
+
except:
|
| 127 |
+
pass
|
| 128 |
+
|
| 129 |
+
# filter that only require gradient decent
|
| 130 |
+
filtered_parameters = []
|
| 131 |
+
params_num = []
|
| 132 |
+
for p in filter(lambda p: p.requires_grad, model.parameters()):
|
| 133 |
+
filtered_parameters.append(p)
|
| 134 |
+
params_num.append(np.prod(p.size()))
|
| 135 |
+
print('Trainable params num : ', sum(params_num))
|
| 136 |
+
# [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]
|
| 137 |
+
|
| 138 |
+
# setup optimizer
|
| 139 |
+
if opt.optim=='adam':
|
| 140 |
+
#optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
|
| 141 |
+
optimizer = optim.Adam(filtered_parameters)
|
| 142 |
+
else:
|
| 143 |
+
optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
|
| 144 |
+
print("Optimizer:")
|
| 145 |
+
print(optimizer)
|
| 146 |
+
|
| 147 |
+
""" final options """
|
| 148 |
+
# print(opt)
|
| 149 |
+
with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file:
|
| 150 |
+
opt_log = '------------ Options -------------\n'
|
| 151 |
+
args = vars(opt)
|
| 152 |
+
for k, v in args.items():
|
| 153 |
+
opt_log += f'{str(k)}: {str(v)}\n'
|
| 154 |
+
opt_log += '---------------------------------------\n'
|
| 155 |
+
print(opt_log)
|
| 156 |
+
opt_file.write(opt_log)
|
| 157 |
+
|
| 158 |
+
""" start training """
|
| 159 |
+
start_iter = 0
|
| 160 |
+
if opt.saved_model != '':
|
| 161 |
+
try:
|
| 162 |
+
start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
|
| 163 |
+
print(f'continue to train, start_iter: {start_iter}')
|
| 164 |
+
except:
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
start_time = time.time()
|
| 168 |
+
best_accuracy = -1
|
| 169 |
+
best_norm_ED = -1
|
| 170 |
+
i = start_iter
|
| 171 |
+
|
| 172 |
+
scaler = GradScaler()
|
| 173 |
+
t1= time.time()
|
| 174 |
+
|
| 175 |
+
while(True):
|
| 176 |
+
# train part
|
| 177 |
+
optimizer.zero_grad(set_to_none=True)
|
| 178 |
+
|
| 179 |
+
if amp:
|
| 180 |
+
with autocast():
|
| 181 |
+
image_tensors, labels = train_dataset.get_batch()
|
| 182 |
+
image = image_tensors.to(device)
|
| 183 |
+
text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
|
| 184 |
+
batch_size = image.size(0)
|
| 185 |
+
|
| 186 |
+
if 'CTC' in opt.Prediction:
|
| 187 |
+
preds = model(image, text).log_softmax(2)
|
| 188 |
+
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
| 189 |
+
preds = preds.permute(1, 0, 2)
|
| 190 |
+
torch.backends.cudnn.enabled = False
|
| 191 |
+
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
|
| 192 |
+
torch.backends.cudnn.enabled = True
|
| 193 |
+
else:
|
| 194 |
+
preds = model(image, text[:, :-1]) # align with Attention.forward
|
| 195 |
+
target = text[:, 1:] # without [GO] Symbol
|
| 196 |
+
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
|
| 197 |
+
scaler.scale(cost).backward()
|
| 198 |
+
scaler.unscale_(optimizer)
|
| 199 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
| 200 |
+
scaler.step(optimizer)
|
| 201 |
+
scaler.update()
|
| 202 |
+
else:
|
| 203 |
+
image_tensors, labels = train_dataset.get_batch()
|
| 204 |
+
image = image_tensors.to(device)
|
| 205 |
+
text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
|
| 206 |
+
batch_size = image.size(0)
|
| 207 |
+
if 'CTC' in opt.Prediction:
|
| 208 |
+
preds = model(image, text).log_softmax(2)
|
| 209 |
+
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
|
| 210 |
+
preds = preds.permute(1, 0, 2)
|
| 211 |
+
torch.backends.cudnn.enabled = False
|
| 212 |
+
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
|
| 213 |
+
torch.backends.cudnn.enabled = True
|
| 214 |
+
else:
|
| 215 |
+
preds = model(image, text[:, :-1]) # align with Attention.forward
|
| 216 |
+
target = text[:, 1:] # without [GO] Symbol
|
| 217 |
+
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
|
| 218 |
+
cost.backward()
|
| 219 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
| 220 |
+
optimizer.step()
|
| 221 |
+
loss_avg.add(cost)
|
| 222 |
+
|
| 223 |
+
# validation part
|
| 224 |
+
if (i % opt.valInterval == 0) and (i!=0):
|
| 225 |
+
print('training time: ', time.time()-t1)
|
| 226 |
+
t1=time.time()
|
| 227 |
+
elapsed_time = time.time() - start_time
|
| 228 |
+
# for log
|
| 229 |
+
with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log:
|
| 230 |
+
model.eval()
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
|
| 233 |
+
infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
|
| 234 |
+
model.train()
|
| 235 |
+
|
| 236 |
+
# training loss and validation loss
|
| 237 |
+
loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
|
| 238 |
+
loss_avg.reset()
|
| 239 |
+
|
| 240 |
+
current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'
|
| 241 |
+
|
| 242 |
+
# keep best accuracy model (on valid dataset)
|
| 243 |
+
if current_accuracy > best_accuracy:
|
| 244 |
+
best_accuracy = current_accuracy
|
| 245 |
+
torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
|
| 246 |
+
if current_norm_ED > best_norm_ED:
|
| 247 |
+
best_norm_ED = current_norm_ED
|
| 248 |
+
torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
|
| 249 |
+
best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'
|
| 250 |
+
|
| 251 |
+
loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
|
| 252 |
+
print(loss_model_log)
|
| 253 |
+
log.write(loss_model_log + '\n')
|
| 254 |
+
|
| 255 |
+
# show some predicted results
|
| 256 |
+
dashed_line = '-' * 80
|
| 257 |
+
head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
|
| 258 |
+
predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
|
| 259 |
+
|
| 260 |
+
#show_number = min(show_number, len(labels))
|
| 261 |
+
|
| 262 |
+
start = random.randint(0,len(labels) - show_number )
|
| 263 |
+
for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]):
|
| 264 |
+
if 'Attn' in opt.Prediction:
|
| 265 |
+
gt = gt[:gt.find('[s]')]
|
| 266 |
+
pred = pred[:pred.find('[s]')]
|
| 267 |
+
|
| 268 |
+
predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
|
| 269 |
+
predicted_result_log += f'{dashed_line}'
|
| 270 |
+
print(predicted_result_log)
|
| 271 |
+
log.write(predicted_result_log + '\n')
|
| 272 |
+
print('validation time: ', time.time()-t1)
|
| 273 |
+
t1=time.time()
|
| 274 |
+
# save model per 1e+4 iter.
|
| 275 |
+
if (i + 1) % 1e+4 == 0:
|
| 276 |
+
torch.save(
|
| 277 |
+
model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
|
| 278 |
+
|
| 279 |
+
if i == opt.num_iter:
|
| 280 |
+
print('end the training')
|
| 281 |
+
sys.exit()
|
| 282 |
+
i += 1
|