cascade / app.py
bobbypaton
Add 50 heavy atom limit; CASCADE 1.0 version label; iframe headers
6a594c1
"""
CASCADE – Flask app for HF Spaces
Mounted at /cascade_v1/
"""
import os
import json
import uuid
import redis
from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from NMR_Prediction.valid import validate_smiles
# Blueprint that carries every CASCADE route defined in this module.
bp = Blueprint("cascade", __name__)

# Module-wide Redis connection used by the views below to enqueue jobs and
# read results. decode_responses=True is requested so replies come back as
# str and can be passed straight to json.loads.
redis_client = redis.StrictRedis(
    host="localhost",
    port=6379,
    db=0,
    decode_responses=True,
)
# ── Pages ─────────────────────────────────────────────────────────────────────
@bp.route("/")
@bp.route("")
@bp.route("/predict/")
@bp.route("/predict")
@bp.route("/home/")
@bp.route("/about/")
def predict():
    """Serve the prediction page for all of the blueprint's landing URLs."""
    page_template = "cascade/predict.html"
    return render_template(page_template)
# ── Job submission ────────────────────────────────────────────────────────────
def _submit_job(smiles, type_):
    """Validate *smiles* and, if accepted, enqueue a prediction job in Redis.

    Args:
        smiles: SMILES string supplied by the client.
        type_: Job type value stored alongside the SMILES in the job record.

    Returns:
        The new task id (a hex string) when the job was queued, otherwise a
        JSON response carrying the rejection message and a null ``task_id``.
        Callers tell the two apart with ``isinstance(result, str)``.
    """
    if validate_smiles(smiles):
        job_id = uuid.uuid4().hex
        job_record = json.dumps({"smiles": smiles, "type_": type_})
        # Store the job details first, then publish the id on the queue.
        redis_client.set(f"task_detail_{job_id}", job_record)
        redis_client.rpush("task_queue", job_id)
        return job_id

    return jsonify({"message": "Invalid SMILES or molecule exceeds 50 heavy atoms", "task_id": None})
@bp.route("/predict_NMR_C/", methods=["POST"])
def predict_NMR_C():
    """Queue a job from the posted ``smiles``/``type_`` form fields (C endpoint).

    Returns the rejection response produced by ``_submit_job`` unchanged, or
    a JSON confirmation containing the new task id.
    """
    form = request.form
    outcome = _submit_job(form["smiles"], form["type_"])
    if isinstance(outcome, str):
        # A plain string is the task id of a successfully queued job.
        return jsonify({"message": "Molecule submitted to C queue", "task_id": outcome})
    return outcome
@bp.route("/predict_NMR_H/", methods=["POST"])
def predict_NMR_H():
    """Queue a job from the posted ``smiles``/``type_`` form fields (H endpoint).

    Returns the rejection response produced by ``_submit_job`` unchanged, or
    a JSON confirmation containing the new task id.
    """
    form = request.form
    outcome = _submit_job(form["smiles"], form["type_"])
    if isinstance(outcome, str):
        # A plain string is the task id of a successfully queued job.
        return jsonify({"message": "Molecule submitted to H queue", "task_id": outcome})
    return outcome
# ── check_task ────────────────────────────────────────────────────────────────
@bp.route("/check_task/")
def check_task():
    """Report the state of the task named by the ``task_id`` query parameter.

    The plain-text body is ``"running"`` while no result is stored,
    ``"Error1"`` when the stored result contains ``errMessage``, and
    ``"done"`` otherwise. The HTTP status is 200 in all three cases.
    """
    task_id = request.args["task_id"]
    stored = redis_client.get(f"task_result_{task_id}")

    if not stored:
        status = "running"
    elif "errMessage" in json.loads(stored):
        status = "Error1"
    else:
        status = "done"
    return status, 200
# ── get_result ────────────────────────────────────────────────────────────────
def _render_shift_svg(smiles, nucleus, shift_map):
    """Render *smiles* as SVG text with predicted shifts used as atom labels.

    Args:
        smiles: SMILES string of the molecule to draw.
        nucleus: ``"H"`` draws the molecule with explicit hydrogens on a
            700x500 canvas; any other value draws it without added
            hydrogens on a 600x450 canvas.
        shift_map: Mapping of 1-based atom index to the label text to show.

    Returns:
        The SVG markup with the ``svg:`` namespace markers stripped, or
        ``None`` when RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        return None

    if nucleus == "H":
        mol = Chem.AddHs(mol)
        width, height = 700, 500
    else:
        width, height = 600, 450

    AllChem.Compute2DCoords(mol)
    mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
    n_label = mol.GetNumAtoms()
    drawer = rdMolDraw2D.MolDraw2DSVG(width, height)

    opts = drawer.drawOptions()
    for atom_1idx, shift_val in shift_map.items():
        atom_0idx = atom_1idx - 1
        # Only label atoms that exist; the lower bound rejects an index of 0
        # (or a negative one) in the shift text.
        if 0 <= atom_0idx < n_label:
            opts.atomLabels[atom_0idx] = shift_val
    opts.clearBackground = False
    opts.bondLineWidth = 1
    opts.padding = 0.15
    opts.additionalAtomLabelPadding = 0.1

    drawer.DrawMolecule(mol_draw)
    drawer.FinishDrawing()
    return drawer.GetDrawingText().replace("svg:", "").replace(":svg", "")


@bp.route("/get_result/")
def get_result():
    """Return the stored prediction for ``task_id`` together with a labelled SVG.

    Reads the ``task_id`` query parameter and loads ``task_result_<task_id>``
    from Redis.

    Returns:
        * 404 with ``{"error": "Result not found"}`` when nothing is stored.
        * 500 with ``{"error": ...}`` when the stored result carries
          ``errMessage`` or its SMILES cannot be parsed for drawing.
        * Otherwise a JSON object with the SVG drawing, the SMILES, the
          nucleus, conformer SDFs, the weighted and per-conformer shift
          text, relative energies and the task id.
    """
    task_id = request.args["task_id"]
    raw = redis_client.get(f"task_result_{task_id}")
    if not raw:
        return jsonify({"error": "Result not found"}), 404
    result = json.loads(raw)
    if "errMessage" in result:
        return jsonify({"error": result["errMessage"]}), 500

    smiles = result["smiles"]
    nucleus = result.get("type_", "C")
    weighted_shift_txt = result["weightedShiftTxt"]

    # The shift text is a ';'-separated list of "index,shift" pairs; entries
    # that do not split into exactly two fields are ignored.
    shift_map = {}
    for item in filter(None, weighted_shift_txt.split(";")):
        parts = item.split(",")
        if len(parts) == 2:
            shift_map[int(parts[0])] = parts[1]

    svg = _render_shift_svg(smiles, nucleus, shift_map)
    if svg is None:
        # Answer with the same JSON error shape as the errMessage path
        # instead of letting RDKit raise on a None molecule.
        return jsonify({"error": "Could not parse SMILES for drawing"}), 500

    return jsonify({
        "svg": svg,
        "smiles": smiles,
        "nucleus": nucleus,
        "conf_sdfs": result.get("conf_sdfs", []),
        "weightedShift": weighted_shift_txt,
        "confShift": result["confShiftTxt"],
        "relative_E": result["relative_E"],
        "taskId": task_id,
    })
# ── Download as CSV ───────────────────────────────────────────────────────────
@bp.route("/download/<task_id>/")
def download(task_id):
    """Send the weighted shifts of a finished task as a CSV attachment.

    Aborts with 404 when no result is stored for *task_id* or when the
    stored result contains ``errMessage``. The CSV has a header row followed
    by one ``index,shift`` row per well-formed entry of ``weightedShiftTxt``.
    """
    stored = redis_client.get(f"task_result_{task_id}")
    if not stored:
        abort(404)
    result = json.loads(stored)
    if "errMessage" in result:
        abort(404)

    nucleus = result.get("type_", "C")
    rows = [f"Atom Index,Predicted {'1H' if nucleus == 'H' else '13C'} Shift (ppm)"]
    for entry in result["weightedShiftTxt"].split(";"):
        if not entry:
            continue
        fields = entry.split(",")
        # Keep only entries that are exactly an "index,shift" pair.
        if len(fields) == 2:
            rows.append(",".join(fields))

    disposition = f"attachment; filename=cascade_{task_id[:8]}.csv"
    return Response(
        "\n".join(rows),
        mimetype="text/csv",
        headers={"Content-Disposition": disposition},
    )
def create_app():
    """Build the Flask application with the CASCADE blueprint mounted.

    The blueprint is registered under ``/cascade_v1``. A small HTML page
    that forwards the browser to ``/cascade_v1/predict/`` is served at
    ``/cascade_v1`` and ``/``, and every response has its framing headers
    rewritten so the pages can be embedded in an iframe.

    Returns:
        The configured ``Flask`` instance.
    """
    app = Flask(__name__, static_folder="static", template_folder="templates")
    app.register_blueprint(bp, url_prefix="/cascade_v1")

    def root():
        """Return a page that forwards to the predict view via meta refresh and JS."""
        return '''<!DOCTYPE html>
<html>
<head>
<meta http-equiv="refresh" content="0; url=/cascade_v1/predict/">
<script>window.location.href="/cascade_v1/predict/";</script>
</head>
<body></body>
</html>'''

    # Registered in the same order the stacked decorators would apply
    # (innermost first).
    for rule in ("/cascade_v1", "/"):
        app.add_url_rule(rule, view_func=root)

    def remove_iframe_restriction(response):
        """Drop X-Frame-Options and allow any frame ancestor on *response*."""
        response.headers.pop("X-Frame-Options", None)
        response.headers["Content-Security-Policy"] = "frame-ancestors *"
        return response

    app.after_request(remove_iframe_restriction)

    return app
# Module-level application object, created at import time.
app = create_app()


if __name__ == "__main__":
    # Direct execution: listen on all interfaces, port 7860, debugger off.
    app.run(debug=False, port=7860, host="0.0.0.0")