| import importlib.util |
| import logging |
| import os |
| import sys |
| from pathlib import Path |
| from types import ModuleType |
|
|
| import pytest |
|
|
| LOGGER = logging.getLogger(__name__) |
|
|
|
|
| @pytest.fixture(scope="module") |
| def server_url(): |
| server_url = os.getenv("H2OGPT_SERVER") |
| if not server_url: |
| LOGGER.info("Couldn't find a running h2oGPT server. Hence starting a one.") |
|
|
| generate = _import_module_from_h2ogpt("generate.py") |
| generate.main( |
| base_model="h2oai/h2ogpt-oig-oasst1-512-6_9b", |
| prompt_type="human_bot", |
| chat=False, |
| stream_output=False, |
| gradio=True, |
| num_beams=1, |
| block_gradio_exit=False, |
| ) |
| server_url = "http://0.0.0.0:7860" |
| LOGGER.info(f"h2oGPT server started at '{server_url}'.") |
| return server_url |
|
|
|
|
| @pytest.fixture(scope="module") |
| def h2ogpt_key(): |
| return os.getenv("H2OGPT_KEY") or os.getenv("H2OGPT_H2OGPT_KEY") |
|
|
|
|
| @pytest.fixture(scope="module") |
| def eval_func_param_names(): |
| parameters = _import_module_from_h2ogpt("src/evaluate_params.py") |
| return parameters.eval_func_param_names |
|
|
|
|
| def _import_module_from_h2ogpt(file_name: str) -> ModuleType: |
| h2ogpt_dir = Path(__file__).parent.parent.parent |
| file_path = (h2ogpt_dir / file_name).absolute() |
| module_name = file_path.stem |
|
|
| LOGGER.info(f"Loading module '{module_name}' from '{file_path}'.") |
| spec = importlib.util.spec_from_file_location(module_name, file_path) |
| if not spec: |
| raise Exception(f"Couldn't load module '{module_name}' from '{file_path}'.") |
| module = importlib.util.module_from_spec(spec) |
| sys.modules[module_name] = module |
| spec.loader.exec_module(module) |
| return module |
|
|