| from collections import defaultdict, namedtuple |
| import gc |
| import os |
| import re |
| import time |
| import tracemalloc |
| from typing import Callable, List, Optional |
| from ray.util.annotations import DeveloperAPI |
|
|
# Keys for which log_once() has already returned True (until the set is cleared).
_logged = set()
# When True, log_once() always returns False (set by disable_log_once_globally()).
_disabled = False
# When True, log_once() clears `_logged` once more than 60 seconds have passed
# since `_last_logged` (set by enable_periodic_logging()).
_periodic_log = False
# time.time() stamp written by log_once() whenever it returns True or clears
# `_logged`.
_last_logged = 0.0
|
|
|
|
@DeveloperAPI
def log_once(key):
    """Returns True if this is the "first" call for a given key.

    Various logging settings can adjust the definition of "first":

    * After `disable_log_once_globally()`, every call returns False.
    * After `enable_periodic_logging()`, a call for an already-seen key that
      happens more than 60 seconds after the last recorded timestamp forgets
      all seen keys. That call itself still returns False; the following call
      for any key counts as "first" again.

    Example:

    .. testcode::

        import logging
        from ray.util.debug import log_once

        logger = logging.getLogger(__name__)
        if log_once("some_key"):
            logger.info("Some verbose logging statement")
    """
    global _last_logged

    # Globally muted: nothing ever counts as a first call.
    if _disabled:
        return False

    # Unseen key: remember it and stamp the time of this "first" call.
    if key not in _logged:
        _logged.add(key)
        _last_logged = time.time()
        return True

    # Seen key: with periodic logging on, forget every key once the last stamp
    # is older than 60 seconds so that later calls can log again.
    if _periodic_log and time.time() - _last_logged > 60.0:
        _logged.clear()
        _last_logged = time.time()

    return False
|
|
|
|
@DeveloperAPI
def disable_log_once_globally():
    """Make log_once() return False in this process.

    Sets the module-level `_disabled` flag, which log_once() checks before
    anything else. Nothing in this function turns the flag off again.
    """
    global _disabled

    _disabled = True
|
|
|
|
@DeveloperAPI
def enable_periodic_logging():
    """Make log_once() periodically return True in this process.

    Sets the module-level `_periodic_log` flag. With it set, log_once()
    forgets all previously seen keys once more than 60 seconds have passed
    since its last recorded timestamp.
    """
    global _periodic_log

    _periodic_log = True
|
|
|
|
@DeveloperAPI
def reset_log_once(key: Optional[str] = None):
    """Resets log_once for the provided key.

    If you don't provide a key, resets log_once for all keys.

    Args:
        key: The key to forget, or None to forget every key. Resetting a key
            that was never logged is a no-op.
    """
    if key is not None:
        # `discard` (unlike `remove`) does not raise for unknown keys.
        _logged.discard(key)
        return

    _logged.clear()
|
|
|
|
| |
| |
# Record describing one allocation traceback that
# `_find_memory_leaks_in_table` flagged as a possible memory leak.
Suspect = DeveloperAPI(
    namedtuple(
        "Suspect",
        (
            # The tracemalloc traceback of the allocation (the key under which
            # `_take_snapshot` files its size samples).
            "traceback",
            # Last recorded size minus first recorded size for this traceback,
            # in bytes.
            "memory_increase",
            # Slope of the linear regression over the samples
            # (x=sample index, y=recorded size).
            "slope",
            # The `rvalue` reported by that same linear regression.
            "rvalue",
            # The full list of recorded sizes, one entry per snapshot.
            "hist",
        ),
    )
)
|
|
|
|
def _test_some_code_for_memory_leaks(
    desc: str,
    init: Optional[Callable[[], None]],
    code: Callable[[], None],
    repeats: int,
    max_num_trials: int = 1,
) -> List[Suspect]:
    """Runs given code (and init code) n times and checks for memory leaks.

    Each trial starts tracemalloc, runs `init` once (if given), then runs
    `code` repeatedly, taking a memory snapshot after every run. Tracebacks
    whose recorded size grows steadily are reported as suspects and re-tested
    (with more repeats) in the next trial, until either no suspects remain or
    `max_num_trials` is reached. Progress and findings are printed to stdout.

    tracemalloc is always stopped again at the end of a trial, even if `init`,
    `code` or the analysis raises.

    Args:
        desc: A descriptor of the test.
        init: Optional code to be executed initially.
        code: The actual code to be checked for producing memory leaks.
        repeats: How many times to repeatedly execute `code`.
        max_num_trials: The maximum number of trials to run. A new trial is only
            run, if the previous one produced a memory leak. For all non-1st trials,
            `repeats` calculates as: actual_repeats = `repeats` * (trial + 1), where
            the first trial is 0.

    Returns:
        A list of Suspect objects, describing possible memory leaks. If list
        is empty, no leaks have been found.
    """

    def _i_print(i):
        # Progress indicator: one dot per 10 iterations, plus the running count
        # and a line break after every 100.
        if (i + 1) % 10 == 0:
            print(".", end="" if (i + 1) % 100 else f" {i + 1}\n", flush=True)

    # Tracebacks flagged by the previous trial. While non-empty, snapshots only
    # record these (see `_take_snapshot`).
    suspicious = set()
    suspicious_stats = []
    for trial in range(max_num_trials):
        # Trace allocations, storing up to 20 frames per traceback.
        tracemalloc.start(20)
        try:
            # Maps traceback -> list of sizes, one entry per snapshot.
            table = defaultdict(list)

            # Later trials run longer so that a real leak becomes more
            # pronounced while a false positive flattens out.
            actual_repeats = repeats * (trial + 1)

            print(f"{desc} {actual_repeats} times.")

            # Initialize if necessary.
            if init is not None:
                init()
            # Run `code` n times, each time taking a memory snapshot.
            for i in range(actual_repeats):
                _i_print(i)
                # Collect garbage before and after the run so that unreachable
                # objects do not show up in the snapshot.
                gc.collect()
                code()
                gc.collect()
                _take_snapshot(table, suspicious)
            print("\n")

            # The snapshots above were filtered by last trial's suspects; now
            # replace those with this trial's findings.
            suspicious.clear()
            suspicious_stats.clear()
            suspects = _find_memory_leaks_in_table(table)
            ranked = sorted(
                suspects, key=lambda s: s.memory_increase, reverse=True
            )
            for suspect in ranked:
                # Only print out the biggest offender in detail.
                if not suspicious:
                    _pprint_suspect(suspect)
                    print("-> added to retry list")
                suspicious.add(suspect.traceback)
                suspicious_stats.append(suspect)
        finally:
            # Always stop tracing (this also discards the collected traces), so
            # a failure in `init`/`code` does not leave tracemalloc enabled.
            tracemalloc.stop()

        # Nothing suspicious found -> exit the trial loop and return.
        if not suspicious:
            print("No remaining suspects found -> returning")
            break

        # Some suspects found: list at most ten of them, biggest first.
        print(f"{len(suspicious)} suspects found. Top-ten:")
        for i, s in enumerate(suspicious_stats[:10]):
            print(
                f"{i}) line={s.traceback[-1]} mem-increase={s.memory_increase}B "
                f"slope={s.slope}B/detection rval={s.rvalue}"
            )

    # Print out the final top offender.
    if suspicious_stats:
        _pprint_suspect(suspicious_stats[0])

    return suspicious_stats
|
|
|
|
def _take_snapshot(table, suspicious=None):
    """Appends the current per-traceback allocation sizes to `table`.

    Takes a tracemalloc snapshot (tracemalloc must already be tracing), groups
    the traced memory by traceback and looks at the first 100 entries of the
    resulting statistics.

    Args:
        table: Mapping from traceback to a list of sizes; the size of every
            kept entry is appended to its traceback's list. Must create
            missing lists on access (e.g. `defaultdict(list)`).
        suspicious: Optional collection of tracebacks. If given and non-empty,
            only entries whose traceback is contained in it are recorded;
            otherwise all of the first 100 entries are recorded.
    """
    stats = tracemalloc.take_snapshot().statistics("traceback")

    for entry in stats[:100]:
        trace = entry.traceback
        # With an active suspect list, skip everything that is not on it.
        if suspicious and trace not in suspicious:
            continue
        table[trace].append(entry.size)
|
|
|
|
def _find_memory_leaks_in_table(table):
    """Returns a Suspect for every traceback whose size history looks like a leak.

    Args:
        table: Mapping from traceback to the list of sizes recorded for it
            (one entry per snapshot, see `_take_snapshot`).

    Returns:
        A list of Suspect tuples (possibly empty). A traceback is reported if
        its size grew by more than 1000 overall and a linear regression over
        its history is both steep and well-correlated enough.
    """
    # Kept function-local, as in the original implementation.
    import numpy as np
    import scipy.stats

    # Substrings identifying allocation sites that are not reported: the
    # tracing machinery, the IDE, psutil and this very module.
    drive_separator = "\\\\" if os.name == "nt" else "/"
    ignored_markers = (
        "tracemalloc",
        "pycharm",
        "thirdparty_files/psutil",
        re.sub("\\.", drive_separator, __name__) + ".py",
    )

    # Accepted (min slope, min rvalue) pairs: the flatter the regression line,
    # the stronger the correlation has to be.
    thresholds = ((60.0, 0.875), (20.0, 0.9), (10.0, 0.95))

    found = []
    for trace, sizes in table.items():
        # Growth between the first and the last sample.
        growth = sizes[-1] - sizes[0]
        if growth <= 0.0:
            continue

        # Only the last frame of the traceback is matched against the markers.
        last_frame = str(trace[-1])
        if any(marker in last_frame for marker in ignored_markers):
            continue

        # Fit a line through the history (x=sample index, y=size).
        fit = scipy.stats.linregress(x=np.arange(len(sizes)), y=np.array(sizes))

        if growth <= 1000:
            continue
        if not any(
            fit.slope > min_slope and fit.rvalue > min_rvalue
            for min_slope, min_rvalue in thresholds
        ):
            continue

        found.append(
            Suspect(
                traceback=trace,
                memory_increase=growth,
                slope=fit.slope,
                rvalue=fit.rvalue,
                hist=sizes,
            )
        )

    return found
|
|
|
|
def _pprint_suspect(suspect):
    """Prints a report about one Suspect to stdout.

    The report consists of a header line, the formatted traceback and the
    suspect's memory increase, regression slope and rvalue.

    Args:
        suspect: The Suspect to describe. Its `traceback` must offer a
            `format()` method returning the lines to print.
    """
    header = (
        "Most suspicious memory allocation in traceback "
        "(only printing out this one, but all (less suspicious)"
        " suspects will be investigated as well):"
    )
    print(header)

    trace_text = "\n".join(suspect.traceback.format())
    print(trace_text)

    print(f"Increase total={suspect.memory_increase}B")
    print(f"Slope={suspect.slope} B/detection")
    print(f"Rval={suspect.rvalue}")
|
|