| from flask import Flask, render_template, request, flash, jsonify |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from huggingface_hub import login |
| import os, json |
|
|
# Flask application object. The secret key is drawn fresh from the OS on every
# process start; Flask needs one for the session cookie behind flash().
app = Flask(__name__)
app.secret_key = os.urandom(24)


# Serving state shared by the views below. Everything starts out unset and is
# filled in by index() once a model has been loaded; generate() reads it.
ee_model = ee_tokenizer = ee_config = ee_model_name = None


# Address shown to the user on the setup page: the HTTPS host taken from the
# SPACE_HOST environment variable when it is set, otherwise a localhost
# fallback on port 7860.
SPACE_HOST = os.getenv("SPACE_HOST", "")
if SPACE_HOST:
    SPACE_URL = f"https://{SPACE_HOST}"
else:
    SPACE_URL = "http://localhost:7860"
|
|
|
|
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the setup page and, on POST, load the model named in the form.

    GET renders ``index.html`` from the current module-level state.

    POST reads ``ee_model_name`` and ``hf_token`` from the submitted form, logs
    in to the Hugging Face Hub, loads the causal-LM weights and the tokenizer,
    and parses ``ee_config.json`` downloaded from the same repository. The
    outcome is reported to the page through ``flash`` messages.

    The ``ee_*`` globals are replaced only after *every* load step succeeded,
    so a failed attempt leaves a previously loaded model, tokenizer, config
    and name untouched instead of a half-updated mix.

    Returns:
        The rendered ``index.html`` template, given ``server_ready``,
        ``model_name`` and ``space_url``.
    """
    global ee_model, ee_tokenizer, ee_config, ee_model_name

    if request.method == "POST":
        requested_name = request.form["ee_model_name"].strip()
        hf_token = request.form["hf_token"].strip()

        try:
            login(token=hf_token)

            # SECURITY NOTE(review): trust_remote_code=True lets the repository
            # named in the form run its own Python code in this process. Only
            # expose this page to users who are trusted to pick the repository.
            model = AutoModelForCausalLM.from_pretrained(
                requested_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                requested_name, trust_remote_code=True
            )

            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(requested_name, "ee_config.json")
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            # Top-level boundary: any login/download/parse failure is shown to
            # the user; the previously committed state stays as it was.
            flash(f"Error: {str(e)}", "danger")
        else:
            # Commit all four pieces together so generate() never observes a
            # model without its matching tokenizer/config/name.
            ee_model = model
            ee_tokenizer = tokenizer
            ee_config = config
            ee_model_name = requested_name

            flash(f"✅ Model loaded: {ee_model_name}", "success")
            flash("Point your Client Space to this Space's URL below.", "info")

    server_ready = ee_model is not None
    return render_template(
        "index.html",
        server_ready=server_ready,
        # The name is committed together with the model, so one flag gates both.
        model_name=ee_model_name if server_ready else None,
        space_url=SPACE_URL,
    )
|
|
|
|
@app.route("/generate", methods=["POST"])
def generate():
    """Run the loaded model on client-supplied embeddings.

    Expects a JSON object with:

    * ``inputs_embeds`` — nested list, converted to a tensor in the model's
      dtype on the model's device. Must be 3-D (presumably
      ``(batch, seq_len, hidden)`` — confirm against the client).
    * ``attention_mask`` — optional nested list; when absent or ``null`` an
      all-ones mask over the first two dimensions of ``inputs_embeds`` is used.
    * ``past_key_values`` — tolerated but ignored. The forward pass runs with
      ``use_cache=False`` and no KV cache is returned, so each call must carry
      the full sequence.

    Design intent, as stated by the protocol this endpoint belongs to: the
    client sends sigma-encrypted embeddings, receives the last hidden state
    (still in sigma-space) and applies ``lm_head`` itself, so token IDs and
    plaintext never reach the server.

    NOTE(review): the model is invoked through its causal-LM wrapper, which
    normally evaluates the LM head as well; those logits are simply discarded
    and never returned. If the head must not run server-side, call the base
    model instead — confirm.

    Returns:
        ``{"last_hidden": ...}`` on success; ``{"error": ...}`` with status
        400 when no model is loaded or the payload is malformed; ``{"error":
        ..., "traceback": ...}`` with status 500 when the forward pass fails.
    """
    # Snapshot the global once: index() may replace it while this request runs.
    model = ee_model
    if model is None:
        return jsonify({"error": "Server not started yet"}), 400

    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so client mistakes are reported as 400 rather than as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or data.get("inputs_embeds") is None:
        return jsonify({"error": "JSON body with 'inputs_embeds' is required"}), 400

    try:
        model_dtype = next(model.parameters()).dtype
        device = model.device

        try:
            inputs_embeds = torch.tensor(data["inputs_embeds"]).to(
                dtype=model_dtype, device=device
            )
            if inputs_embeds.dim() != 3:
                raise ValueError(
                    "inputs_embeds must be 3-dimensional, "
                    f"got {inputs_embeds.dim()} dimension(s)"
                )

            mask_payload = data.get("attention_mask")
            if mask_payload is None:
                # Cover every row of the batch, not just a batch of one.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:2], dtype=torch.long, device=device
                )
            else:
                attention_mask = torch.tensor(mask_payload).to(device=device)
        except (TypeError, ValueError, RuntimeError) as e:
            # Ragged or non-numeric lists, wrong rank: the client's fault.
            return jsonify({"error": f"Invalid payload: {e}"}), 400

        with torch.no_grad():
            out = model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=False,
                output_hidden_states=True,
            )

        last_hidden = out.hidden_states[-1]

        return jsonify({
            "last_hidden": last_hidden.cpu().tolist(),
        })

    except Exception as e:
        import traceback

        app.logger.exception("generate() failed")
        # NOTE(review): the traceback is still returned because existing
        # clients may display it, but it discloses server internals; consider
        # dropping it from the response once clients no longer need it.
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
|
|
|
|
# Script entry point: start Flask's built-in server when the file is executed
# directly rather than imported.
if __name__ == "__main__":
    # Listen on every interface, on the same port the localhost fallback URL
    # advertises.
    app.run(port=7860, host="0.0.0.0")