| from flask import Flask, request, Response, stream_with_context |
| import requests |
| import os |
| import json |
|
|
# The WSGI application object; every route in this module is registered on it.
app = Flask(import_name=__name__)
|
|
| |
# Base URL that requests are forwarded to. Trailing slashes are stripped so that
# appending a "/"-prefixed path never produces "//" in the upstream URL.
TARGET_API = os.getenv("TARGET_API", "https://huggingface.co").rstrip("/")
|
|
| |
| |
def get_path_mappings():
    """Load the path-prefix rewrite table from the ``PATH_MAPPINGS`` env var.

    The variable is expected to hold a JSON object mapping incoming path
    prefixes to the prefixes that replace them upstream, e.g.
    ``{"/api": "/v1"}``. When the variable is unset, the identity mapping
    ``{"/": "/"}`` is returned. When it is not valid JSON, or is not an object
    whose keys and values are all strings, a warning is logged and the same
    identity mapping is returned.

    Returns:
        dict[str, str]: incoming prefix -> replacement prefix.
    """
    import logging  # local import keeps this helper self-contained

    default = {"/": "/"}
    raw = os.getenv("PATH_MAPPINGS")
    if raw is None:
        return default

    logger = logging.getLogger(__name__)
    try:
        mappings = json.loads(raw)
    except json.JSONDecodeError as err:
        logger.warning("PATH_MAPPINGS is not valid JSON (%s); using %r", err, default)
        return default

    # Anything other than a str -> str object would fail later, per request,
    # when the proxy calls .items()/startswith/concatenation on it.
    if not isinstance(mappings, dict) or not all(
        isinstance(prefix, str) and isinstance(replacement, str)
        for prefix, replacement in mappings.items()
    ):
        logger.warning(
            "PATH_MAPPINGS must be a JSON object of string to string; using %r",
            default,
        )
        return default
    return mappings


# Resolved once at import time; changing the env var requires a restart.
PATH_MAPPINGS = get_path_mappings()
|
|
|
|
# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Size in bytes of each body chunk relayed from the upstream response.
_CHUNK_SIZE = 8192


def _map_path(full_path):
    """Rewrite *full_path* using the longest matching prefix in PATH_MAPPINGS.

    Picking the longest prefix keeps a catch-all entry such as ``"/"`` from
    shadowing more specific entries, regardless of their order in the JSON.
    A path that matches no prefix is returned unchanged.
    """
    matches = [prefix for prefix in PATH_MAPPINGS if full_path.startswith(prefix)]
    if not matches:
        return full_path
    prefix = max(matches, key=len)
    return PATH_MAPPINGS[prefix] + full_path[len(prefix):]


def _connection_tokens(headers):
    """Return the lower-cased header names listed in a ``Connection`` header."""
    listed = headers.get("Connection", "")
    return {token.strip().lower() for token in listed.split(",") if token.strip()}


def _upstream_request_headers():
    """Copy the incoming request headers that should be sent upstream.

    ``Host`` is dropped because requests derives it from the target URL.
    ``Content-Length`` is dropped because requests recomputes it for the
    forwarded body. Hop-by-hop headers are dropped because they apply only
    to the client connection.
    """
    excluded = (
        _HOP_BY_HOP_HEADERS
        | _connection_tokens(request.headers)
        | {"host", "content-length"}
    )
    headers = {
        key: value
        for key, value in request.headers
        if key.lower() not in excluded
    }
    # The upstream body is relayed without decoding, so only content codings
    # the client itself accepts may be negotiated. Without this, requests
    # would advertise its own default Accept-Encoding on the client's behalf.
    if not any(key.lower() == "accept-encoding" for key in headers):
        headers["Accept-Encoding"] = "identity"
    return headers


def _downstream_response_headers(upstream):
    """Return the upstream response headers to relay to the client.

    Headers are read from ``upstream.raw.headers`` so that repeated fields
    such as ``Set-Cookie`` stay separate instead of being comma-joined.
    ``Content-Length`` and ``Content-Encoding`` are kept because the body is
    relayed byte-for-byte.
    """
    excluded = _HOP_BY_HOP_HEADERS | _connection_tokens(upstream.headers)
    return [
        (key, value)
        for key, value in upstream.raw.headers.items()
        if key.lower() not in excluded
    ]


@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
def proxy(path):
    """Forward the request for ``/<path>`` to TARGET_API and stream the reply.

    The path is rewritten via PATH_MAPPINGS (longest prefix wins). The
    following are forwarded:

    * the HTTP method;
    * the full query string, including repeated keys;
    * the raw request body, whatever its content type;
    * end-to-end headers.

    Redirects are followed by requests. ``HEAD`` also reaches this view,
    because Werkzeug adds it to GET routes, and it is forwarded as ``HEAD``.

    Args:
        path: request path after the leading slash, as captured by the route.

    Returns:
        A streaming Response carrying the upstream status, non-hop-by-hop
        headers, and undecoded body. If the upstream request cannot be made
        at all, a 502 plain-text response is returned instead.
    """
    full_path = _map_path(f"/{path}")
    target_url = f"{TARGET_API}{full_path}"

    try:
        upstream = requests.request(
            request.method,
            target_url,
            headers=_upstream_request_headers(),
            # items(multi=True) keeps every value of repeated keys (?a=1&a=2).
            params=list(request.args.items(multi=True)),
            # Raw bytes: JSON, form, and binary bodies are forwarded unchanged.
            data=request.get_data(),
            stream=True,
        )
    except requests.RequestException:
        return Response(
            "Bad Gateway: upstream request failed",
            status=502,
            mimetype="text/plain",
        )

    def relay_body():
        # decode_content=False relays the bytes exactly as received, so they
        # match the relayed Content-Encoding/Content-Length headers.
        try:
            yield from upstream.raw.stream(_CHUNK_SIZE, decode_content=False)
        finally:
            # Runs on normal completion and when the client disconnects.
            upstream.close()

    return Response(
        stream_with_context(relay_body()),
        status=upstream.status_code,
        headers=_downstream_response_headers(upstream),
    )
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Answer GET on the root path with a fixed status message."""
    status_message = "service running."
    return status_message
|
|
|
|
if __name__ == "__main__":
    # Development-server entry point: listen on every interface, port 7860.
    app.run(debug=False, port=7860, host="0.0.0.0")
|
|