-
Notifications
You must be signed in to change notification settings - Fork 468
/
Copy pathapp.py
138 lines (116 loc) · 5.35 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
# Pick the deep-learning backend and the device used for inference.
# The TensorFlow backend is used whenever docTR reports it as available;
# otherwise the PyTorch backend is loaded.
if is_tf_available():
    import tensorflow as tf

    from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    # TensorFlow names its device types in upper case ("CPU", "GPU"). Querying
    # with lower-case "gpu" returns no devices, which silently forced CPU
    # inference even on GPU machines.
    if tf.config.experimental.list_physical_devices("GPU"):
        forward_device = tf.device("/gpu:0")
    else:
        forward_device = tf.device("/cpu:0")
else:
    import torch

    from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    # Use the first CUDA device when one is visible, CPU otherwise
    forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def main(det_archs, reco_archs):
    """Build the Streamlit layout and run the OCR demo on the selected page.

    The page is laid out in four columns (input page, segmentation heatmap,
    OCR output, page reconstitution) with a sidebar holding the document
    upload, the model choice and the predictor parameters. Nothing is
    analysed until the "Analyze page" button is pressed.

    Args:
        det_archs: choices offered in the "Text detection model" select box
        reco_archs: choices offered in the "Text recognition model" select box
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("docTR: Document Text Recognition")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare the extension case-insensitively so that e.g. "scan.PDF" is
        # parsed as a PDF instead of being handed to the image decoder
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # Pages are shown 1-based in the UI and converted back to a 0-based index
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    st.sidebar.markdown("**Backend**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # For newline
    st.sidebar.write("\n")
    # Only straight pages or possible rotation
    st.sidebar.title("Parameters")
    assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
    # Disable page orientation detection
    disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False)
    # Disable crop orientation detection
    disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False)
    # Straighten pages
    straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
    # Export as straight boxes
    export_straight_boxes = st.sidebar.checkbox("Export as straight boxes", value=False)
    st.sidebar.write("\n")
    # Binarization threshold
    bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    st.sidebar.write("\n")
    # Box threshold
    box_thresh = st.sidebar.slider("Box threshold", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
    st.sidebar.write("\n")

    if st.sidebar.button("Analyze page"):
        if uploaded_file is None:
            # `page` only exists once a document has been uploaded
            st.sidebar.write("Please upload a document")
        else:
            with st.spinner("Loading model..."):
                predictor = load_predictor(
                    det_arch=det_arch,
                    reco_arch=reco_arch,
                    assume_straight_pages=assume_straight_pages,
                    straighten_pages=straighten_pages,
                    export_as_straight_boxes=export_straight_boxes,
                    disable_page_orientation=disable_page_orientation,
                    disable_crop_orientation=disable_crop_orientation,
                    bin_thresh=bin_thresh,
                    box_thresh=box_thresh,
                    device=forward_device,
                )

            with st.spinner("Analyzing..."):
                # Forward the image to the model
                seg_map = forward_image(predictor, page, forward_device)
                seg_map = np.squeeze(seg_map)
                # cv2.resize takes (width, height): bring the map back to the page size
                seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

                # Plot the raw heatmap
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis("off")
                cols[1].pyplot(fig)
                # pyplot keeps every figure alive until it is closed; release it
                # once rendered so repeated analyses do not accumulate figures
                plt.close(fig)

                # Plot OCR output
                out = predictor([page])
                # Export once and reuse it for both the visualisation and the JSON view
                page_export = out.pages[0].export()
                fig = visualize_page(page_export, out.pages[0].page, interactive=False, add_labels=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Page reconstitution under input page
                # (skipped when pages are neither assumed straight nor straightened)
                if assume_straight_pages or straighten_pages:
                    img = out.pages[0].synthesize()
                    cols[3].image(img, clamp=True)

                # Display JSON
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export, expanded=False)
# Script entry point: build the demo with the architectures exposed by the
# backend module selected at import time.
if __name__ == "__main__":
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)