| """ |
| CASCADE β Flask app for HF Spaces |
| Mounted at /cascade_v1/ |
| """ |
|
|
| import os |
| import json |
| import uuid |
|
|
| import redis |
| from flask import Flask, Blueprint, request, jsonify, render_template, send_file, abort, redirect, Response |
|
|
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
| from rdkit.Chem.Draw import rdMolDraw2D |
|
|
| from NMR_Prediction.valid import validate_smiles |
|
|
# Blueprint carrying every CASCADE route; create_app() registers it under the
# "/cascade_v1" URL prefix.
bp = Blueprint(name="cascade", import_name=__name__)
|
|
# Shared Redis connection. The handlers in this module use it both as the job
# queue (the "task_queue" list) and as the key/value store for
# "task_detail_<id>" and "task_result_<id>" entries.
# Connection settings can be overridden through the environment; the defaults
# (localhost:6379, db 0) are the values this module previously hard-coded.
# decode_responses=True makes reads return str instead of bytes.
redis_client = redis.StrictRedis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
    db=int(os.environ.get("REDIS_DB", "0")),
    decode_responses=True,
)
|
|
| |
@bp.route("/")
@bp.route("")
@bp.route("/predict/")
@bp.route("/predict")
@bp.route("/home/")
@bp.route("/about/")
def predict():
    """Render the prediction page.

    Every landing-style URL of the blueprint (its root, /predict, /home/ and
    /about/) serves this same template.
    """
    template_name = "cascade/predict.html"
    return render_template(template_name)
|
|
|
|
| |
def _submit_job(smiles, type_):
    """Validate *smiles* and enqueue a prediction job in Redis.

    The job description is stored under ``task_detail_<task_id>`` and the
    task id is appended to the ``task_queue`` list. Both writes are sent in a
    single MULTI/EXEC pipeline: one round trip, applied together, so a
    dropped connection between them can no longer leave an orphaned
    ``task_detail_*`` key without a queue entry.

    Args:
        smiles: SMILES string submitted by the client.
        type_: Job/nucleus type, stored verbatim alongside the SMILES.

    Returns:
        The new task id (``uuid4().hex``) as a ``str`` on success, or a Flask
        JSON response carrying ``"task_id": None`` when ``validate_smiles``
        rejects the input. Callers distinguish the two cases with
        ``isinstance(result, str)``.
    """
    if not validate_smiles(smiles):
        return jsonify({"message": "Invalid SMILES or molecule exceeds 50 heavy atoms", "task_id": None})

    task_id = uuid.uuid4().hex
    detail = json.dumps({"smiles": smiles, "type_": type_})
    with redis_client.pipeline(transaction=True) as pipe:
        # The detail is written before the id is pushed, so whoever pops the id
        # can always find its detail.
        pipe.set(f"task_detail_{task_id}", detail)
        pipe.rpush("task_queue", task_id)
        pipe.execute()
    return task_id
|
|
|
|
@bp.route("/predict_NMR_C/", methods=["POST"])
def predict_NMR_C():
    """Submit the posted SMILES to the 13C prediction queue.

    Form fields:
        smiles: Required. If it is missing, ``request.form[...]`` raises and
            Flask replies 400 Bad Request.
        type_: Optional. Defaults to ``"C"``, this endpoint's nucleus. An
            explicit value is forwarded unchanged, as before.

    Returns:
        JSON with ``message`` and ``task_id``. On a validation failure,
        ``task_id`` is ``None``.
    """
    result = _submit_job(request.form["smiles"], request.form.get("type_", "C"))
    if not isinstance(result, str):
        # Validation failed; _submit_job already built the JSON reply.
        return result
    return jsonify({"message": "Molecule submitted to C queue", "task_id": result})
|
|
|
|
@bp.route("/predict_NMR_H/", methods=["POST"])
def predict_NMR_H():
    """Submit the posted SMILES to the 1H prediction queue.

    Form fields:
        smiles: Required. If it is missing, ``request.form[...]`` raises and
            Flask replies 400 Bad Request.
        type_: Optional. Defaults to ``"H"``, this endpoint's nucleus. An
            explicit value is forwarded unchanged, as before.

    Returns:
        JSON with ``message`` and ``task_id``. On a validation failure,
        ``task_id`` is ``None``.
    """
    result = _submit_job(request.form["smiles"], request.form.get("type_", "H"))
    if not isinstance(result, str):
        # Validation failed; _submit_job already built the JSON reply.
        return result
    return jsonify({"message": "Molecule submitted to H queue", "task_id": result})
|
|
|
|
| |
@bp.route("/check_task/")
def check_task():
    """Report the status of ``task_id`` as a plain-text body with HTTP 200.

    The body is one of:
      * ``"running"``: nothing stored under ``task_result_<id>`` yet (this
        is also the answer for ids that were never submitted);
      * ``"Error1"``: the stored result contains an ``errMessage`` key;
      * ``"done"``: any other stored result.
    """
    task_id = request.args["task_id"]
    stored = redis_client.get(f"task_result_{task_id}")
    if not stored:
        status = "running"
    elif "errMessage" in json.loads(stored):
        status = "Error1"
    else:
        status = "done"
    return status, 200
|
|
|
|
| |
def _parse_weighted_shifts(text):
    """Parse ``"idx,shift;idx,shift;..."`` into ``{idx: shift_string}``.

    Indices are 1-based (the renderer subtracts 1 before labelling atoms).
    Empty segments, segments that do not split into exactly two fields, and
    segments whose index is not an integer are skipped instead of failing
    the whole request. Shift values are kept as the original strings.
    """
    shifts = {}
    for item in filter(None, text.split(";")):
        parts = item.split(",")
        if len(parts) != 2:
            continue
        try:
            index = int(parts[0])
        except ValueError:
            continue
        shifts[index] = parts[1]
    return shifts


def _render_shift_svg(smiles, nucleus, shift_map):
    """Draw *smiles* as SVG, labelling atoms with their predicted shifts.

    For ``nucleus == "H"``, explicit hydrogens are added first, so the 1-based
    indices in *shift_map* refer to the H-expanded molecule, and a larger
    700x500 canvas is used. Any other nucleus is drawn on 600x450.

    Returns:
        The SVG markup, or ``None`` if RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    if nucleus == "H":
        mol = Chem.AddHs(mol)
        width, height = 700, 500
    else:
        width, height = 600, 450
    AllChem.Compute2DCoords(mol)
    mol_draw = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=True)
    n_atoms = mol.GetNumAtoms()

    drawer = rdMolDraw2D.MolDraw2DSVG(width, height)
    opts = drawer.drawOptions()
    for atom_1idx, shift_val in shift_map.items():
        atom_0idx = atom_1idx - 1
        # Only label atoms that exist. The lower bound stops index 0 (or a
        # negative index) from becoming a negative RDKit atom id.
        if 0 <= atom_0idx < n_atoms:
            opts.atomLabels[atom_0idx] = shift_val
    opts.clearBackground = False
    opts.bondLineWidth = 1
    opts.padding = 0.15
    opts.additionalAtomLabelPadding = 0.1

    drawer.DrawMolecule(mol_draw)
    drawer.FinishDrawing()
    # Drop "svg:" / ":svg" namespace prefixes from RDKit's output
    # (presumably so the markup can be inlined into the page).
    return drawer.GetDrawingText().replace("svg:", "").replace(":svg", "")


@bp.route("/get_result/")
def get_result():
    """Return the stored prediction for the ``task_id`` query arg as JSON.

    Responses:
        404 ``{"error": "Result not found"}`` when nothing is stored yet.
        500 ``{"error": <errMessage>}`` when the stored result has an
            ``errMessage``.
        500 ``{"error": "Stored SMILES could not be parsed"}`` when RDKit
            rejects the stored SMILES.
        200 otherwise, with the shift-annotated SVG, the SMILES, the nucleus
            (``type_``, defaulting to ``"C"``), and the stored
            ``weightedShiftTxt``/``confShiftTxt``/``relative_E``/``conf_sdfs``
            fields.
    """
    task_id = request.args["task_id"]
    raw = redis_client.get(f"task_result_{task_id}")
    if not raw:
        return jsonify({"error": "Result not found"}), 404

    result = json.loads(raw)
    if "errMessage" in result:
        return jsonify({"error": result["errMessage"]}), 500

    smiles = result["smiles"]
    nucleus = result.get("type_", "C")
    weighted_shift_txt = result["weightedShiftTxt"]

    svg = _render_shift_svg(smiles, nucleus, _parse_weighted_shifts(weighted_shift_txt))
    if svg is None:
        return jsonify({"error": "Stored SMILES could not be parsed"}), 500

    return jsonify({
        "svg": svg,
        "smiles": smiles,
        "nucleus": nucleus,
        "conf_sdfs": result.get("conf_sdfs", []),
        "weightedShift": weighted_shift_txt,
        "confShift": result["confShiftTxt"],
        "relative_E": result["relative_E"],
        "taskId": task_id,
    })
|
|
|
|
| |
@bp.route("/download/<task_id>/")
def download(task_id):
    """Send a task's weighted shifts as a CSV attachment.

    Responds 404 when no result is stored for *task_id*, or when the stored
    result carries an ``errMessage``. The CSV has a header row followed by
    one ``index,shift`` row per well-formed entry of ``weightedShiftTxt``.
    """
    stored = redis_client.get(f"task_result_{task_id}")
    if not stored:
        abort(404)
    result = json.loads(stored)
    if "errMessage" in result:
        abort(404)

    isotope = "1H" if result.get("type_", "C") == "H" else "13C"
    rows = [f"Atom Index,Predicted {isotope} Shift (ppm)"]
    pairs = (entry.split(",") for entry in result["weightedShiftTxt"].split(";") if entry)
    rows.extend(f"{pair[0]},{pair[1]}" for pair in pairs if len(pair) == 2)

    return Response(
        "\n".join(rows),
        mimetype="text/csv",
        headers={"Content-Disposition": f"attachment; filename=cascade_{task_id[:8]}.csv"},
    )
|
|
|
|
def create_app():
    """Build the Flask application.

    * Mounts the CASCADE blueprint under ``/cascade_v1``.
    * Serves a small HTML page at ``/`` and ``/cascade_v1`` that forwards the
      browser to ``/cascade_v1/predict/``.
    * Rewrites the framing headers on every response so the app can be
      embedded in an iframe.

    NOTE(review): the blueprint's ``""`` rule also resolves to
    ``/cascade_v1``. Confirm which handler actually serves that URL.
    """
    app = Flask(__name__, static_folder="static", template_folder="templates")
    app.register_blueprint(bp, url_prefix="/cascade_v1")

    @app.route("/")
    @app.route("/cascade_v1")
    def root():
        # Client-side redirect: an immediate meta refresh plus a JS fallback.
        page = '''<!DOCTYPE html>
<html>
<head>
<meta http-equiv="refresh" content="0; url=/cascade_v1/predict/">
<script>window.location.href="/cascade_v1/predict/";</script>
</head>
<body></body>
</html>'''
        return page

    @app.after_request
    def remove_iframe_restriction(response):
        # Drop X-Frame-Options and let any origin frame the app. Note that
        # this replaces any Content-Security-Policy a view may have set.
        headers = response.headers
        headers.pop("X-Frame-Options", None)
        headers["Content-Security-Policy"] = "frame-ancestors *"
        return response

    return app
|
|
# Module-level application object (presumably what the production server
# imports; confirm against the deployment config).
app = create_app()


if __name__ == "__main__":
    # Development entry point. The port stays 7860 unless PORT is set in the
    # environment.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), debug=False)
|
|