""" CASCADE – Flask app for HF Spaces Mounted at /cascade_v1/ """ import os import json import uuid import redis from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Chem.Draw import rdMolDraw2D from NMR_Prediction.valid import validate_smiles bp = Blueprint("cascade", __name__) redis_client = redis.StrictRedis( host="localhost", port=6379, db=0, decode_responses=True ) # ── Pages ───────────────────────────────────────────────────────────────────── @bp.route("/") @bp.route("") @bp.route("/predict/") @bp.route("/predict") @bp.route("/home/") @bp.route("/about/") def predict(): return render_template("cascade/predict.html") # ── Job submission ──────────────────────────────────────────────────────────── def _submit_job(smiles, type_): if not validate_smiles(smiles): return jsonify({"message": "Invalid SMILES or molecule exceeds 50 heavy atoms", "task_id": None}) task_id = uuid.uuid4().hex redis_client.set(f"task_detail_{task_id}", json.dumps({"smiles": smiles, "type_": type_})) redis_client.rpush("task_queue", task_id) return task_id @bp.route("/predict_NMR_C/", methods=["POST"]) def predict_NMR_C(): result = _submit_job(request.form["smiles"], request.form["type_"]) if not isinstance(result, str): return result return jsonify({"message": "Molecule submitted to C queue", "task_id": result}) @bp.route("/predict_NMR_H/", methods=["POST"]) def predict_NMR_H(): result = _submit_job(request.form["smiles"], request.form["type_"]) if not isinstance(result, str): return result return jsonify({"message": "Molecule submitted to H queue", "task_id": result}) # ── check_task ──────────────────────────────────────────────────────────────── @bp.route("/check_task/") def check_task(): raw = redis_client.get(f"task_result_{request.args['task_id']}") if not raw: return "running", 200 result = json.loads(raw) if "errMessage" in result: return "Error1", 200 return "done", 200 # ── get_result ──────────────────────────────────────────────────────────────── @bp.route("/get_result/") def get_result(): task_id = request.args["task_id"] raw = redis_client.get(f"task_result_{task_id}") if not raw: return jsonify({"error": "Result not found"}), 404 result = json.loads(raw) if "errMessage" in result: return jsonify({"error": result["errMessage"]}), 500 smiles = result["smiles"] nucleus = result.get("type_", "C") weighted_shift_txt = result["weightedShiftTxt"] shift_map = {} for item in filter(None, weighted_shift_txt.split(";")): parts = item.split(",") if len(parts) == 2: shift_map[int(parts[0])] = parts[1] if nucleus == "H": mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) AllChem.Compute2DCoords(mol) mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True) n_label = mol.GetNumAtoms() drawer = rdMolDraw2D.MolDraw2DSVG(700, 500) else: mol = Chem.MolFromSmiles(smiles) AllChem.Compute2DCoords(mol) mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True) n_label = mol.GetNumAtoms() drawer = rdMolDraw2D.MolDraw2DSVG(600, 450) opts = drawer.drawOptions() for atom_1idx, shift_val in shift_map.items(): atom_0idx = atom_1idx - 1 if atom_0idx < n_label: opts.atomLabels[atom_0idx] = shift_val opts.clearBackground = False opts.bondLineWidth = 1 opts.padding = 0.15 opts.additionalAtomLabelPadding = 0.1 drawer.DrawMolecule(mol_draw) drawer.FinishDrawing() svg = drawer.GetDrawingText().replace("svg:", "").replace(":svg", "") return jsonify({ "svg": svg, "smiles": smiles, "nucleus": nucleus, "conf_sdfs": result.get("conf_sdfs", []), "weightedShift": weighted_shift_txt, "confShift": result["confShiftTxt"], "relative_E": result["relative_E"], "taskId": task_id, }) # ── Download as CSV ─────────────────────────────────────────────────────────── @bp.route("/download//") def download(task_id): raw = redis_client.get(f"task_result_{task_id}") if not raw: abort(404) result = json.loads(raw) if "errMessage" in result: abort(404) nucleus = result.get("type_", "C") header = f"Atom Index,Predicted {'1H' if nucleus == 'H' else '13C'} Shift (ppm)" lines = [header] for item in filter(None, result["weightedShiftTxt"].split(";")): parts = item.split(",") if len(parts) == 2: lines.append(f"{parts[0]},{parts[1]}") return Response( "\n".join(lines), mimetype="text/csv", headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"} ) def create_app(): app = Flask(__name__, static_folder="static", template_folder="templates") app.register_blueprint(bp, url_prefix="/cascade_v1") @app.route("/") @app.route("/cascade_v1") def root(): return ''' ''' @app.after_request def remove_iframe_restriction(response): response.headers.pop("X-Frame-Options", None) response.headers["Content-Security-Policy"] = "frame-ancestors *" return response return app app = create_app() if __name__ == "__main__": app.run(host="0.0.0.0", port=7860, debug=False)