dkescape commited on
Commit
6632323
·
verified ·
1 Parent(s): b61db58

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import logging
4
+
5
+ # ----------------------------
6
+ # 1. Warning & logging setup
7
+ # ----------------------------
8
+ # Suppress FutureWarning from timm internals
9
+ warnings.filterwarnings(
10
+ "ignore",
11
+ category=FutureWarning,
12
+ module="timm.models.layers"
13
+ )
14
+ # Suppress UserWarning from modelscope (e.g. missing preprocessor config)
15
+ warnings.filterwarnings(
16
+ "ignore",
17
+ category=UserWarning,
18
+ module="modelscope"
19
+ )
20
+ # Only show ERROR+ logs from modelscope
21
+ logging.getLogger("modelscope").setLevel(logging.ERROR)
22
+
23
+ # ----------------------------
24
+ # 2. Standard imports
25
+ # ----------------------------
26
+ import cv2
27
+ import tempfile
28
+ import gradio as gr
29
+ import numpy as np
30
+ from PIL import Image, ImageEnhance, ImageFilter
31
+ from modelscope.outputs import OutputKeys
32
+ from modelscope.pipelines import pipeline
33
+ from modelscope.utils.constant import Tasks
34
+
35
+ # ----------------------------
36
+ # 3. Load your colorization model
37
+ # ----------------------------
38
+ img_colorization = pipeline(
39
+ Tasks.image_colorization,
40
+ model="iic/cv_ddcolor_image-colorization",
41
+ model_revision="v1.02",
42
+ )
43
+
44
+ # ----------------------------
45
+ # 4. Image processing fns
46
+ # ----------------------------
47
+ def colorize_image(img_path: str) -> str:
48
+ image = cv2.imread(str(img_path))
49
+ output = img_colorization(image[..., ::-1])
50
+ result = output[OutputKeys.OUTPUT_IMG].astype(np.uint8)
51
+
52
+ temp_dir = tempfile.mkdtemp()
53
+ out_path = os.path.join(temp_dir, "colorized.png")
54
+ cv2.imwrite(out_path, result)
55
+ return out_path
56
+
57
+
58
+ def enhance_image(
59
+ img_path: str,
60
+ brightness: float = 1.0,
61
+ contrast: float = 1.0,
62
+ edge_enhance: bool = False
63
+ ) -> str:
64
+ image = Image.open(img_path)
65
+ image = ImageEnhance.Brightness(image).enhance(brightness)
66
+ image = ImageEnhance.Contrast(image).enhance(contrast)
67
+ if edge_enhance:
68
+ image = image.filter(ImageFilter.EDGE_ENHANCE)
69
+
70
+ temp_dir = tempfile.mkdtemp()
71
+ enhanced_path = os.path.join(temp_dir, "enhanced.png")
72
+ image.save(enhanced_path)
73
+ return enhanced_path
74
+
75
+
76
+ def process_image(
77
+ img_path: str,
78
+ brightness: float,
79
+ contrast: float,
80
+ edge_enhance: bool,
81
+ output_format: str
82
+ ):
83
+ # Colorize → Enhance → Re‑save in chosen format
84
+ colorized_path = colorize_image(img_path)
85
+ enhanced_path = enhance_image(colorized_path, brightness, contrast, edge_enhance)
86
+
87
+ img = Image.open(enhanced_path)
88
+ temp_dir = tempfile.mkdtemp()
89
+ filename = f"colorized_image.{output_format.lower()}"
90
+ output_path = os.path.join(temp_dir, filename)
91
+ img.save(output_path, format=output_format.upper())
92
+
93
+ # Return side-by-side gallery and downloadable file
94
+ return ([img_path, enhanced_path], output_path)
95
+
96
+ # ----------------------------
97
+ # 5. Gradio UI + custom CSS
98
+ # ----------------------------
99
+ custom_css = """
100
+ body { background-color: #f0f2f5; }
101
+ .gradio-container { max-width: 900px !important; margin: auto !important; }
102
+ #header { background-color: #4CAF50; padding: 20px; border-radius: 8px;
103
+ text-align: center; margin-bottom: 20px; }
104
+ #header h2, #header p { color: white; margin: 0; }
105
+ #header p { margin-top: 5px; font-size: 1rem; }
106
+ #control-panel { background: white; padding: 20px; border-radius: 8px;
107
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1); margin-bottom: 20px; }
108
+ #submit-btn { background-color: #4CAF50 !important; color: white !important;
109
+ border-radius: 8px !important; font-weight: bold;
110
+ padding: 10px 20px !important; margin-top: 10px !important; }
111
+ #control-panel .gr-row { gap: 15px; }
112
+ .gr-slider, .gr-checkbox, .gr-dropdown { margin-top: 10px; }
113
+ #comparison_gallery { background: white; padding: 10px;
114
+ border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
115
+ #download-btn { margin-top: 15px !important; }
116
+ """
117
+
118
+ TITLE = "🌈 Color Restorization Model"
119
+ DESCRIPTION = "Bring your old black & white photos back to life—upload, adjust, and download in vivid color."
120
+
121
+ with gr.Blocks(title=TITLE, css=custom_css) as app:
122
+ # Header
123
+ gr.HTML(
124
+ """
125
+ <div id="header">
126
+ <h2>🌈 Color Restorization Model</h2>
127
+ <p>Bring your old black & white photos back to life—upload, adjust, and download in vivid color.</p>
128
+ </div>
129
+ """
130
+ )
131
+
132
+ # Controls & results
133
+ with gr.Column(elem_id="control-panel"):
134
+ with gr.Row():
135
+ # Inputs
136
+ with gr.Column():
137
+ input_image = gr.Image(type="filepath", label="Upload B&W Image", interactive=True)
138
+ brightness_slider = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
139
+ contrast_slider = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
140
+ edge_enhance_checkbox = gr.Checkbox(label="Apply Edge Enhancement")
141
+ output_format_dropdown = gr.Dropdown(["PNG", "JPEG", "TIFF"], value="PNG", label="Output Format")
142
+ submit_btn = gr.Button("Colorize", elem_id="submit-btn")
143
+
144
+ # Outputs
145
+ with gr.Column():
146
+ comparison_gallery = gr.Gallery(
147
+ label="Original vs. Colorized",
148
+ columns=2,
149
+ elem_id="comparison_gallery",
150
+ height="auto"
151
+ )
152
+ download_btn = gr.File(label="Download Colorized Image", elem_id="download-btn")
153
+
154
+ # Wire up UI listener with API name
155
+ submit_btn.click(
156
+ fn=process_image,
157
+ inputs=[
158
+ input_image,
159
+ brightness_slider,
160
+ contrast_slider,
161
+ edge_enhance_checkbox,
162
+ output_format_dropdown
163
+ ],
164
+ outputs=[comparison_gallery, download_btn],
165
+ api_name="process_image"
166
+ )
167
+
168
+ # Optional: additional direct API route (unrelated to button click)
169
+ gr.api(process_image, api_name="process_image_direct")
170
+
171
+ # Launch with queue and API visible
172
+ if __name__ == "__main__":
173
+ port = int(os.environ.get("PORT", 7860))
174
+ app.queue()
175
+ app.launch(server_name="0.0.0.0", server_port=port, show_api=True)