| from collections import deque
|
| from datetime import datetime
|
| import io
|
| import logging
|
| import sys
|
| import threading
|
|
|
# In-memory log of captured output: a bounded deque of {"t": ..., "m": ...}
# entries. Stays None until setup_logger() creates it.
logs = None

# LogInterceptor instances that setup_logger() installs as sys.stdout and
# sys.stderr; both stay None until then.
stdout_interceptor = stderr_interceptor = None
|
|
|
|
|
class LogInterceptor(io.TextIOWrapper):
    """Text stream that records everything written through it.

    The interceptor wraps the binary buffer of an existing text stream
    (``stream.buffer``), so output still reaches the original destination.
    Each ``write()`` is also recorded as ``{"t": <ISO timestamp>, "m": <data>}``:

    * in the module-level ``logs`` deque, when it has been created, and
    * in a per-interceptor batch, which ``flush()`` hands to every callback
      registered via ``on_flush()``.
    """

    def __init__(self, stream, *args, **kwargs):
        # Reuse the wrapped stream's buffer, encoding and line-buffering mode
        # so output looks exactly as if it had been written to ``stream``.
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        # Guards _logs_since_flush and updates to the shared ``logs`` deque.
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record ``data`` in the logs, then write it to the underlying stream.

        Returns whatever ``TextIOWrapper.write`` returns (the character count).
        """
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)
            if logs is not None:
                # A leading carriage return overwrites the previous partial
                # line (e.g. progress bars). Replace that entry instead of
                # filling the log with one entry per update. Only do this when
                # a previous entry exists and it did not end its line.
                if (
                    isinstance(data, str)
                    and data.startswith("\r")
                    and logs
                    and not logs[-1]["m"].endswith("\n")
                ):
                    logs.pop()
                logs.append(entry)
        # Written outside the lock, so the lock is not held during the
        # underlying I/O.
        return super().write(data)

    def flush(self):
        """Flush the underlying stream and deliver the pending batch.

        Every registered callback receives the same list of entries written
        since the previous flush. The batch is reset even when no callback is
        registered, so it cannot grow without bound.
        """
        super().flush()
        # Take the batch atomically with respect to write(), so no entry
        # appended concurrently is lost or delivered twice.
        with self._lock:
            batch = self._logs_since_flush
            self._logs_since_flush = []
            callbacks = list(self._flush_callbacks)
        # Callbacks run without the lock: a callback that prints would
        # otherwise block re-acquiring it in write().
        for cb in callbacks:
            cb(batch)

    def on_flush(self, callback):
        """Register ``callback(entries)`` to be called on every flush()."""
        self._flush_callbacks.append(callback)
|
|
|
|
|
def get_logs():
    """Return the module's in-memory log.

    This is the deque created by setup_logger(), or None if setup_logger()
    has not run yet.
    """
    captured = logs
    return captured
|
|
|
|
|
def on_flush(callback):
    """Register ``callback`` with each installed stream interceptor.

    The stdout interceptor is handled first, then the stderr one. An
    interceptor that is still None (setup_logger() has not run) is skipped.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is None:
            continue
        interceptor.on_flush(callback)
|
|
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Install output capture and configure the root logger.

    On the first call this function:

    * creates the in-memory log (a deque of at most ``capacity`` entries);
    * replaces ``sys.stdout`` and ``sys.stderr`` with ``LogInterceptor``
      wrappers, so everything written to them is also recorded; and
    * attaches stream handlers that emit bare messages (``"%(message)s"``)
      to the root logger.

    Later calls return immediately and change nothing.

    Args:
        log_level: Level for the root logger, passed to ``Logger.setLevel``.
        capacity: Maximum number of entries kept in the in-memory log.
        use_stdout: If true, records below ERROR go to stdout and only ERROR
            and above go to stderr. Otherwise every record goes to stderr.
    """
    global logs
    global stdout_interceptor
    global stderr_interceptor

    # Compare against None rather than relying on truthiness: the deque starts
    # out empty, so ``if logs:`` would let a second call re-wrap the streams
    # and add duplicate handlers, printing every record twice.
    if logs is not None:
        return

    logs = deque(maxlen=capacity)

    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    logger = logging.getLogger()
    logger.setLevel(log_level)

    # Created after the stream swap, so the handler's default stream
    # (sys.stderr at construction time) is the stderr interceptor.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    if use_stdout:
        # Split output: only ERROR and above stay on stderr...
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # ...and everything below ERROR goes to the stdout interceptor.
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)
|
|
|
|
|
# Messages passed to log_startup_warning(); print_startup_warnings() re-emits
# them and then empties this list in place.
STARTUP_WARNINGS = list()
|
|
|
|
|
def log_startup_warning(msg):
    """Log ``msg`` as a root-logger warning and remember it for later replay.

    Remembered messages are re-emitted by print_startup_warnings().
    """
    logging.warning(msg)
    pending = STARTUP_WARNINGS
    pending.append(msg)
|
|
|
|
|
def print_startup_warnings():
    """Re-emit every remembered startup warning, in order, then forget them.

    The shared list is emptied in place, so other references to it see the
    cleared list.
    """
    pending = STARTUP_WARNINGS
    for text in pending:
        logging.warning(text)
    del pending[:]
|
|
|