| from flask import Flask, render_template, request, redirect, url_for, jsonify, Request, Response |
| import json |
|
|
| import base64 |
| import hashlib |
| import json |
| import time |
| import os |
| import queue |
| import zstandard as zstd |
|
|
# Pending jobs. enqueue() puts each job here as a JSON string; dequeue() and
# worker() take them off again with json.loads().
q = queue.Queue()

# Every job submitted via enqueue(), updated in place as it moves along.
# Each entry has the shape:
#     {
#         "status": "queue" | "progress" | "done",
#         "prompt": "Some ducks...",
#         "id": "abc123",
#     }
models = []

# Worker slots used by worker(): it reads each entry's "status" and writes
# "prompt" and "id".
# NOTE(review): nothing in this module appears to add entries — confirm who
# populates this list.
WORKERS = []
|
|
def enqueue(prompt: str):
    """Register a new job for *prompt* and put it on the pending queue.

    The job id is ``"<sha256 hex of prompt>.<time.time() at submission>"``.
    The job is appended to ``models`` with status ``"queue"``, and its JSON
    encoding is pushed onto ``q`` for ``dequeue()`` / ``worker()`` to pick up.

    Returns:
        A Flask JSON response ``{"status": "ok", "id": <job id>}``. The id
        lets the client match this job against the ``_id`` later reported to
        ``complete()``.
    """
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    job = {
        "status": "queue",
        "prompt": prompt,
        "id": f"{digest}.{time.time()}",
    }
    models.append(job)
    q.put(json.dumps(job))
    # Adding "id" is backward-compatible: clients that only read "status"
    # keep working.
    return jsonify({"status": "ok", "id": job["id"]})
| |
def dequeue():
    """Hand the oldest pending job to a polling client.

    Returns:
        A Flask JSON response ``{"status": "ok", "prompt": ..., "id": ...}``
        when a job was available, otherwise ``{"status": "empty"}``.
    """
    # EAFP instead of `if not q.empty(): q.get_nowait()`: another request
    # thread or worker() can take the last item between the check and the
    # get, which would make get_nowait() raise queue.Empty (an HTTP 500).
    try:
        raw = q.get_nowait()
    except queue.Empty:
        return jsonify({"status": "empty"})

    job = json.loads(raw)
    # Reflect the hand-off in the job's record ("queue" -> "progress").
    for model in models:
        if model["id"] == job["id"]:
            model["status"] = "progress"
            break

    return jsonify({
        "status": "ok",
        "prompt": job["prompt"],
        "id": job["id"],
    })
| |
def complete(data=None):
    """Mark a job as done and write the files it produced under ``files/``.

    Args:
        data: JSON text (``str`` or ``bytes``) shaped like
            ``{"_id": <job id>, "files": [{"path": <relative path>,
            "data": <base64 of zstd-compressed bytes>}, ...]}``.
            Defaults to the raw request body. The ``/complete`` route has no
            URL variables, so Flask calls this view with no arguments and the
            parameter must be optional.

    Returns:
        A Flask JSON response ``{"status": "ok"}``, or
        ``({"status": "error", "message": "invalid path"}, 400)`` when any
        file path would resolve outside ``files/``. Nothing is written and
        the status is unchanged in that case. Unknown job ids are ignored
        and still answer ``"ok"``.
    """
    if data is None:
        data = request.get_data()
    payload = json.loads(data)

    for model in models:
        if model["id"] != payload["_id"]:
            continue
        # Resolve and decode every file before touching the disk or the job
        # status, so one bad entry cannot leave the job half-written.
        outputs = []
        for entry in payload["files"]:
            target = _resolve_output_path(entry["path"])
            if target is None:
                return jsonify({"status": "error", "message": "invalid path"}), 400
            content = zstd.decompress(base64.b64decode(entry["data"]))
            outputs.append((target, content))
        model["status"] = "done"
        for target, content in outputs:
            _write_output_file(target, content)
        break

    return jsonify({"status": "ok"})


def _resolve_output_path(rel_path):
    """Return the absolute path of *rel_path* under ``files/``, or None if it escapes."""
    root = os.path.abspath("files")
    # Concatenate (rather than os.path.join) so a leading "/" stays relative
    # to files/. abspath normalizes "//" and "..", and the commonpath check
    # below rejects anything that climbs out of root.
    target = os.path.abspath(f"{root}/{rel_path}")
    try:
        inside = os.path.commonpath([root, target]) == root
    except ValueError:  # e.g. paths on different drives (Windows)
        return None
    if not inside or target == root:
        return None
    return target


def _write_output_file(target, content):
    """Write *content* (bytes) to *target*, creating missing parent directories."""
    # Create only the parent directory. Creating the full path as a directory
    # and then removing it fails once the file itself already exists.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    with open(target, "wb") as fh:
        fh.write(content)
| |
def worker():
    """Assign pending jobs to idle entries of ``WORKERS``; loops forever.

    Takes JSON-encoded jobs from ``q``. If some entry of ``WORKERS`` has
    ``status == "idle"``, that entry gets the job's ``prompt`` and ``id`` and
    the job's record in ``models`` is marked ``"progress"``. Otherwise the job
    is put back on the queue and the loop backs off for one second.

    NOTE(review): nothing in this module starts this loop, and it competes
    with ``dequeue()`` for the same queue. Confirm which dispatch path is
    actually used.
    """
    while True:
        try:
            # A blocking get with a timeout replaces the q.empty()/get_nowait()
            # pair, which raced with dequeue() and spun while the queue was empty.
            raw = q.get(timeout=1)
        except queue.Empty:
            continue
        job = json.loads(raw)

        idle = next((w for w in WORKERS if w["status"] == "idle"), None)
        if idle is None:
            # Re-queue the JSON text unchanged: every consumer of q calls
            # json.loads() on what it takes, so a Flask Response (jsonify)
            # must never be put here. jsonify also needs an app context this
            # loop does not have.
            q.put(raw)
            time.sleep(1)
            continue

        idle["prompt"] = job["prompt"]
        idle["id"] = job["id"]
        # NOTE(review): the worker entry's own "status" is left at "idle", so
        # the next job may go to the same worker. Confirm which value marks a
        # worker busy.
        for model in models:
            if model["id"] == job["id"]:
                model["status"] = "progress"
                break
| |
app = Flask(__name__)


def _index():
    """Serve the minimal landing page for ``GET /``."""
    return """
<html>
  <head>
    <title>Snail</title>
  </head>
  <body>
    <h1>Snail</h1>
  </body>
</html>
"""


if __name__ == "__main__":
    # Routes are registered only when this file is run directly.
    app.add_url_rule("/enqueue/<prompt>", "enqueue", enqueue, methods=["POST"])
    app.add_url_rule("/dequeue", "dequeue", dequeue, methods=["GET"])
    app.add_url_rule("/complete", "complete", complete, methods=["POST"])
    app.add_url_rule("/", "index", _index, methods=["GET"])

    # NOTE(review): assigned after Flask() is constructed, so the static URL
    # rule created at construction may keep its default "/static" prefix
    # while serving files from ./public. Confirm the intended URL.
    app.static_folder = "public"

    app.run(port=7860, host="0.0.0.0")