"""Flask front-end that colorizes uploaded images with two models.

Each uploaded image is run through the ECCV16 and SIGGRAPH17 colorizers from
the ``colorizers`` package. The original and both results are saved to disk
and rendered on ``result.html``.
"""

import contextlib
import gc
import os
import uuid

import numpy as np
import torch
from flask import (
    Flask,
    render_template,
    request,
    send_from_directory,
    url_for,
)
from PIL import Image

from colorizers import eccv16, siggraph17
from colorizers.util import postprocess_tens, preprocess_img

# Disable CUDA to save memory on Render. After this patch, every later
# torch.cuda.is_available() call in this process returns False.
torch.cuda.is_available = lambda: False

# Resolved to absolute paths once at startup. file.save() interprets relative
# paths against the current working directory, but Flask's
# send_from_directory() interprets a relative directory against
# app.root_path. Absolute paths make writing and serving agree no matter
# where the server is launched from.
UPLOAD_FOLDER = os.path.abspath('uploads')
OUTPUT_FOLDER = os.path.abspath('outputs')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp'}

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['OUTPUT_FOLDER'] = OUTPUT_FOLDER

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Load models once at startup (CPU only).
print("Loading colorization models...")
colorizer_eccv16 = eccv16(pretrained=True).eval()
colorizer_siggraph17 = siggraph17(pretrained=True).eval()
print("Models loaded successfully!")

# Maps each model name to its model. The name is used both as the output-file
# suffix and as the key of colorize_and_save()'s return value.
_COLORIZERS = {
    'eccv16': colorizer_eccv16,
    'siggraph17': colorizer_siggraph17,
}


def allowed_file(filename):
    """Return True if *filename* ends in one of ALLOWED_EXTENSIONS (case-insensitive)."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


@app.route('/', methods=['GET', 'POST'])
def index():
    """Show the upload form (GET) or colorize the uploaded images (POST).

    On POST, each file with an allowed extension is stored under a random
    UUID name and colorized by every model. Files with a disallowed
    extension are skipped. A file whose colorization raises is logged,
    deleted from the upload folder, and skipped. The processing of the
    remaining files continues.
    """
    if request.method != 'POST':
        return render_template('index.html')

    # An empty <input type="file"> is typically submitted as one part with an
    # empty filename. Drop such parts so "nothing chosen" is reported as such.
    uploads = [f for f in request.files.getlist('files') if f and f.filename]
    if not uploads:
        return render_template('index.html', error='No files selected')

    results = []
    for upload in uploads:
        if not allowed_file(upload.filename):
            continue

        # Reuse the extension allowed_file() validated, so the stored name
        # always carries exactly that (lowercased) extension.
        ext = upload.filename.rsplit('.', 1)[1].lower()
        filename = f'{uuid.uuid4()}.{ext}'
        upload_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        upload.save(upload_path)

        try:
            out_paths = colorize_and_save(upload_path, filename)
        except Exception:
            # Best-effort batch: one bad image must not abort the others.
            # Log with traceback and drop the orphaned upload.
            app.logger.exception('Error processing %s', filename)
            with contextlib.suppress(OSError):
                os.remove(upload_path)
            continue

        results.append({
            'orig_img': url_for('uploaded_file', filename=filename),
            'eccv16_img': url_for('output_file', filename=os.path.basename(out_paths['eccv16'])),
            'siggraph17_img': url_for('output_file', filename=os.path.basename(out_paths['siggraph17'])),
            'filename': os.path.splitext(filename)[0],
        })

    if not results:
        return render_template('index.html', error='No valid files to process')
    return render_template('result.html', images=results, total_count=len(results))


@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve an uploaded original image.

    send_from_directory refuses paths outside UPLOAD_FOLDER and returns 404
    for missing files.
    """
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)


@app.route('/outputs/<filename>')
def output_file(filename):
    """Serve a colorized result image.

    send_from_directory refuses paths outside OUTPUT_FOLDER and returns 404
    for missing files.
    """
    return send_from_directory(app.config['OUTPUT_FOLDER'], filename)


def _load_rgb(img_path):
    """Read *img_path* and return it as an (H, W, 3) uint8 RGB numpy array.

    Converting through PIL maps every input mode (greyscale, palette, RGBA,
    greyscale+alpha) to three RGB channels before the image reaches
    preprocess_img.
    """
    with Image.open(img_path) as im:
        return np.asarray(im.convert('RGB'))


def colorize_and_save(img_path, filename):
    """Colorize one image with every model and save each result as a PNG.

    Args:
        img_path: path of the uploaded source image.
        filename: stored upload name. Its stem prefixes the output files.

    Returns:
        dict mapping each model name ('eccv16', 'siggraph17') to the path
        written, ``<OUTPUT_FOLDER>/<stem>_<model>.png``.
    """
    img = _load_rgb(img_path)
    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
    del img
    base_filename = os.path.splitext(filename)[0]

    out_paths = {}
    for name, model in _COLORIZERS.items():
        with torch.no_grad():
            out_ab = model(tens_l_rs)
        out_img = postprocess_tens(tens_l_orig, out_ab.cpu())
        del out_ab
        # Clip to [0, 1] before scaling so out-of-range values don't wrap
        # around when cast to uint8.
        out_uint8 = (np.clip(out_img, 0, 1) * 255).astype(np.uint8)
        del out_img
        out_path = os.path.join(app.config['OUTPUT_FOLDER'], f'{base_filename}_{name}.png')
        Image.fromarray(out_uint8).save(out_path)
        out_paths[name] = out_path
        # Free this model's result before the next model runs, to keep peak
        # memory low on the small host.
        del out_uint8

    del tens_l_orig, tens_l_rs
    gc.collect()
    return out_paths


if __name__ == '__main__':
    # Support both Render (PORT env var) and HuggingFace Spaces (default
    # 7860). SERVER_PORT is consulted when PORT is unset.
    port = int(os.getenv('PORT', os.getenv('SERVER_PORT', 7860)))
    app.run(host='0.0.0.0', port=port, debug=False)