| import atexit |
| import re |
| import shlex |
| import subprocess |
| from pathlib import Path |
| from tempfile import TemporaryDirectory |
| from typing import Union |
| from gradio import strings |
| import os |
|
|
| from modules.shared import cmd_opts |
|
|
# Hostnames of the supported tunnelling services; ssh_tunnel() passes one of
# these to ssh as the destination host.
LOCALHOST_RUN = "localhost.run"
REMOTE_MOE = "remote.moe"
# Patterns used to pull the public tunnel URL out of a line of ssh output.
# The named group "url" holds the full http(s) URL: one ending in
# ".lhr.life" for localhost.run, one ending in ".remote.moe" for remote.moe.
localhostrun_pattern = re.compile(r"(?P<url>https?://\S+\.lhr\.life)")
remotemoe_pattern = re.compile(r"(?P<url>https?://\S+\.remote\.moe)")
|
|
|
|
def gen_key(path: Union[str, Path]) -> None:
    """Generate a passphrase-less 4096-bit RSA key at *path*.

    Runs ``ssh-keygen`` in quiet mode with an empty passphrase, then
    restricts the resulting key file to owner read/write (mode 0o600).

    Args:
        path: Destination of the key file, as a string or ``Path``.

    Raises:
        subprocess.CalledProcessError: If ``ssh-keygen`` exits with a
            non-zero status.
    """
    path = Path(path)
    # Hand subprocess an explicit argument list instead of formatting a
    # command string and re-splitting it: a destination containing spaces or
    # quote characters must reach ssh-keygen as a single "-f" argument.
    args = [
        "ssh-keygen",
        "-t", "rsa",
        "-b", "4096",
        "-N", "",  # empty passphrase
        "-q",
        "-f", path.as_posix(),
    ]
    subprocess.run(args, check=True)
    path.chmod(0o600)
|
|
|
|
def ssh_tunnel(host: str = LOCALHOST_RUN) -> None:
    """Open a reverse SSH tunnel to *host* and publish the public URL it reports.

    An ``id_rsa`` key two directories above this file is used as the ssh
    identity; it is generated on first use. If generating it there fails,
    a key is generated in a temporary directory instead.

    The ssh process forwards remote port 80 to ``127.0.0.1:<port>``, where
    the port comes from ``cmd_opts.port`` (7860 when unset). Its output is
    scanned for the service's public URL, which is then stored in the
    ``webui_url`` environment variable and written into gradio's
    ``SHARE_LINK_MESSAGE`` string.

    The ssh process (and the temporary key directory, if one was created)
    is cleaned up via ``atexit`` handlers.

    Args:
        host: Tunnel service to connect to; ``LOCALHOST_RUN`` or
            ``REMOTE_MOE``. Any value other than ``LOCALHOST_RUN`` is
            handled with the remote.moe settings.

    Raises:
        RuntimeError: If no public URL appears in the scanned ssh output.
        subprocess.CalledProcessError: If key generation fails in the
            temporary-directory fallback as well.
    """
    ssh_name = "id_rsa"
    ssh_path = Path(__file__).parent.parent / ssh_name

    tmp = None
    if not ssh_path.exists():
        try:
            gen_key(ssh_path)
        except subprocess.CalledProcessError:
            # Could not create the key next to the code; fall back to a
            # throwaway directory that is removed at interpreter exit.
            tmp = TemporaryDirectory()
            ssh_path = Path(tmp.name) / ssh_name
            gen_key(ssh_path)

    port = cmd_opts.port or 7860

    # Build the argument list directly rather than formatting a command
    # string and re-splitting it, so a key path containing spaces (e.g. an
    # install directory with a space in its name) stays a single argument.
    args = [
        "ssh",
        "-R", f"80:127.0.0.1:{port}",
        "-o", "StrictHostKeyChecking=no",
        "-i", ssh_path.as_posix(),
        host,
    ]

    # stderr is merged into stdout so warnings and the URL banner arrive on
    # the same stream and can be read line by line below.
    tunnel = subprocess.Popen(
        args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
    )

    atexit.register(tunnel.terminate)
    if tmp is not None:
        atexit.register(tmp.cleanup)

    tunnel_url = ""
    # Only the first few lines of output are searched for the URL; the limit
    # differs per service (presumably matching the length of each service's
    # greeting banner -- confirm before changing).
    lines = 27 if host == LOCALHOST_RUN else 5
    pattern = localhostrun_pattern if host == LOCALHOST_RUN else remotemoe_pattern

    for _ in range(lines):
        line = tunnel.stdout.readline()
        if line.startswith("Warning"):
            print(line, end="")

        url_match = pattern.search(line)
        if url_match:
            tunnel_url = url_match.group("url")
            break
    else:
        # Loop finished without finding a URL in the scanned lines.
        raise RuntimeError(f"Failed to run {host}")

    # Expose the URL to the rest of the process and make gradio announce it.
    os.environ['webui_url'] = tunnel_url
    strings.en["SHARE_LINK_MESSAGE"] = f"Running on public URL (recommended): {tunnel_url}"
|
|
# Start a tunnel for every service enabled on the command line:
# localhost.run first, then remote.moe.
for _flag_name, _banner, _tunnel_host in (
    ("localhostrun", "localhost.run detected, trying to connect...", LOCALHOST_RUN),
    ("remotemoe", "remote.moe detected, trying to connect...", REMOTE_MOE),
):
    # Each flag is read only when its turn comes, i.e. after any earlier
    # tunnel has already been started.
    if getattr(cmd_opts, _flag_name):
        print(_banner)
        ssh_tunnel(_tunnel_host)
# Do not leave the loop variables behind as module attributes.
del _flag_name, _banner, _tunnel_host
|
|