File size: 8,643 Bytes
cf6a8b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
from collections import defaultdict, namedtuple
import gc
import os
import re
import time
import tracemalloc
from typing import Callable, List, Optional
from ray.util.annotations import DeveloperAPI

# Keys for which log_once() has already returned True (until cleared by
# reset_log_once() or by the periodic reset inside log_once()).
_logged = set()
# When True, log_once() returns False unconditionally
# (set by disable_log_once_globally()).
_disabled = False
# When True, log_once() clears `_logged` once more than 60 seconds have passed
# since `_last_logged` (set by enable_periodic_logging()).
_periodic_log = False
# time.time() value recorded by log_once() whenever it returns True or performs
# a periodic reset. Shared across all keys.
_last_logged = 0.0


@DeveloperAPI
def log_once(key):
    """Tell the caller whether `key` is being seen for the "first" time.

    What counts as "first" depends on the module-wide logging settings:

    - After `disable_log_once_globally()`, this always returns False.
    - A key that is not currently remembered is recorded and True is returned.
    - After `enable_periodic_logging()`, all remembered keys are forgotten once
      more than 60 seconds have passed since the last recorded timestamp. The
      call that triggers this reset still returns False; calls after it
      return True again.

    Example:

        .. testcode::

            import logging
            from ray.util.debug import log_once

            logger = logging.getLogger(__name__)
            if log_once("some_key"):
                logger.info("Some verbose logging statement")

    Args:
        key: Hashable identifier of the log statement.

    Returns:
        True if the caller should log now, False otherwise.
    """

    global _last_logged

    # Globally muted: never report a first call.
    if _disabled:
        return False

    # Unknown key: remember it and let the caller log.
    if key not in _logged:
        _logged.add(key)
        _last_logged = time.time()
        return True

    # Known key: optionally expire everything so later calls may log again.
    if _periodic_log and time.time() - _last_logged > 60.0:
        _logged.clear()
        _last_logged = time.time()

    return False


@DeveloperAPI
def disable_log_once_globally():
    """Make log_once() return False in this process.

    Sets the module-level `_disabled` flag, which log_once() checks before
    looking at the key at all.
    """

    # `global` is required because the assignment would otherwise create a
    # function-local variable instead of updating the module flag.
    global _disabled
    _disabled = True


@DeveloperAPI
def enable_periodic_logging():
    """Make log_once() periodically return True in this process.

    Sets the module-level `_periodic_log` flag. With the flag set, log_once()
    forgets all previously seen keys once more than 60 seconds have passed
    since its last recorded timestamp, so that subsequent calls for those keys
    return True again.
    """

    # `global` is required because the assignment would otherwise create a
    # function-local variable instead of updating the module flag.
    global _periodic_log
    _periodic_log = True


@DeveloperAPI
def reset_log_once(key: Optional[str] = None):
    """Forget that log_once() has already fired.

    Args:
        key: The key to forget. If None (the default), all keys are forgotten.
            Forgetting a key that was never logged is a no-op.
    """
    if key is not None:
        # `discard` (unlike `remove`) does not raise for unknown keys.
        _logged.discard(key)
        return

    _logged.clear()


# A suspicious memory-allocating stack trace that should be re-tested
# to make sure it is not a false positive.
# Instances are created by `_find_memory_leaks_in_table()`.
Suspect = DeveloperAPI(
    namedtuple(
        "Suspect",
        [
            # The stack trace of the allocation, going back n frames, depending
            # on the tracemalloc.start(n) call.
            "traceback",
            # The amount of memory gained by this particular stack trace
            # over the course of the experiment (last recorded size minus
            # first recorded size).
            "memory_increase",
            # The slope of the scipy linear regression (x=iteration; y=memory size).
            "slope",
            # The rvalue (correlation coefficient) of the scipy linear regression.
            "rvalue",
            # The memory size history (list of all memory sizes over all iterations).
            "hist",
        ],
    )
)


def _test_some_code_for_memory_leaks(
    desc: str,
    init: Optional[Callable[[], None]],
    code: Callable[[], None],
    repeats: int,
    max_num_trials: int = 1,
) -> List[Suspect]:
    """Runs given code (and init code) n times and checks for memory leaks.

    For each trial, tracemalloc is started, `init` (if given) is run once and
    `code` is run repeatedly with a memory snapshot taken after every run. The
    per-traceback size histories are then checked for steady growth.

    Args:
        desc: A descriptor of the test (printed as progress information).
        init: Optional code to be executed once at the start of every trial.
        code: The actual code to be checked for producing memory leaks.
        repeats: How many times to repeatedly execute `code`.
        max_num_trials: The maximum number of trials to run. A new trial is only
            run, if the previous one produced a memory leak. For all non-1st trials,
            `repeats` calculates as: actual_repeats = `repeats` * (trial + 1), where
            the first trial is 0.

    Returns:
        A list of Suspect objects, describing possible memory leaks, sorted by
        memory increase (largest first). If list is empty, no leaks have been
        found.

    Raises:
        Any exception raised by `init` or `code` is propagated; tracemalloc is
        stopped before the exception leaves this function.
    """

    def _i_print(i):
        # Progress output: one dot per 10 runs, a running count every 100 runs.
        if (i + 1) % 10 == 0:
            print(".", end="" if (i + 1) % 100 else f" {i + 1}\n", flush=True)

    # Do n trials to make sure a found leak is really one.
    suspicious = set()
    suspicious_stats = []
    for trial in range(max_num_trials):
        # Store up to n frames of each call stack.
        tracemalloc.start(20)
        # Make sure tracing is switched off again even if `init`/`code` (or the
        # analysis) raises; otherwise the whole process would keep paying the
        # tracemalloc overhead.
        try:
            table = defaultdict(list)

            # Repeat running code for n times.
            # Increase repeat value with each trial to make sure stats are more
            # solid each time (avoiding false positives).
            actual_repeats = repeats * (trial + 1)

            print(f"{desc} {actual_repeats} times.")

            # Initialize if necessary.
            if init is not None:
                init()
            # Run `code` n times, each time taking a memory snapshot.
            for i in range(actual_repeats):
                _i_print(i)
                # Manually trigger garbage collection before and after code runs
                # in order to make tracemalloc snapshots as accurate as possible.
                gc.collect()
                code()
                gc.collect()
                # On re-trials, `suspicious` still holds the previous trial's
                # suspects, so only those tracebacks are tracked.
                _take_snapshot(table, suspicious)
            print("\n")

            # Check, which traces have moved up in their memory consumption
            # constantly over time.
            suspicious.clear()
            suspicious_stats.clear()
            # Suspicious memory allocation found?
            suspects = _find_memory_leaks_in_table(table)
            ranked = sorted(suspects, key=lambda s: s.memory_increase, reverse=True)
            for suspect in ranked:
                # Only print out the biggest offender:
                if not suspicious:
                    _pprint_suspect(suspect)
                    print("-> added to retry list")
                suspicious.add(suspect.traceback)
                suspicious_stats.append(suspect)
        finally:
            tracemalloc.stop()

        # Nothing suspicious found -> Exit trial loop and return.
        if not suspicious:
            print("No remaining suspects found -> returning")
            break

        # Some suspicious memory allocations found.
        print(f"{len(suspicious)} suspects found. Top-ten:")
        # Slice to exactly ten entries, as the headline promises.
        for i, s in enumerate(suspicious_stats[:10]):
            print(
                f"{i}) line={s.traceback[-1]} mem-increase={s.memory_increase}B "
                f"slope={s.slope}B/detection rval={s.rvalue}"
            )

    # Print out final top offender.
    if suspicious_stats:
        _pprint_suspect(suspicious_stats[0])

    return suspicious_stats


def _take_snapshot(table, suspicious=None):
    """Record the current per-traceback memory sizes into `table`.

    Args:
        table: Mapping from traceback to a list of sizes; the size of each
            recorded traceback is appended to its list. Expected to create
            missing entries on access (e.g. a `defaultdict(list)`).
        suspicious: Optional collection of tracebacks. If non-empty, only
            tracebacks contained in it are recorded. If None or empty, all of
            the considered tracebacks are recorded.
    """
    # Group all currently traced allocations by their stack trace (as deep as
    # configured via tracemalloc.start(n)).
    stats_by_trace = tracemalloc.take_snapshot().statistics("traceback")

    # An empty/None `suspicious` means "no filter".
    restrict_to_suspects = bool(suspicious)

    # Only look at the first 100 entries of the statistics.
    for entry in stats_by_trace[:100]:
        trace = entry.traceback
        if restrict_to_suspects and trace not in suspicious:
            continue
        table[trace].append(entry.size)


def _find_memory_leaks_in_table(table):
    """Scan per-traceback size histories for steadily growing allocations.

    Args:
        table: Mapping from traceback to its list of recorded memory sizes
            (one entry per snapshot in which the traceback was recorded).

    Returns:
        A list of `Suspect` tuples, one per traceback whose history grew by
        more than 1000 bytes and whose linear fit is steep and well-correlated
        enough (see the thresholds below). Empty if nothing looks like a leak.
    """
    import scipy.stats
    import numpy as np

    # Top frames containing any of these markers are ignored. The last entry
    # is this very module (we are collecting lots of data here, so an increase
    # is expected).
    path_sep = "\\\\" if os.name == "nt" else "/"
    ignored_markers = (
        "tracemalloc",
        "pycharm",
        "thirdparty_files/psutil",
        re.sub("\\.", path_sep, __name__) + ".py",
    )

    # (min slope, min rvalue) pairs: the weaker the slope, the more confidence
    # in the fit is required to flag the traceback.
    thresholds = ((60.0, 0.875), (20.0, 0.9), (10.0, 0.95))

    found = []
    for trace, sizes in table.items():
        # Quick check: did memory grow at all between first and last snapshot?
        growth = sizes[-1] - sizes[0]
        if growth <= 0.0:
            continue

        innermost = str(trace[-1])
        if any(marker in innermost for marker in ignored_markers):
            continue

        # Linear regression over the history yields slope and R-value.
        fit = scipy.stats.linregress(x=np.arange(len(sizes)), y=np.array(sizes))
        slope, rvalue = fit.slope, fit.rvalue

        is_leak = growth > 1000 and any(
            slope > min_slope and rvalue > min_rvalue
            for min_slope, min_rvalue in thresholds
        )
        if not is_leak:
            continue

        found.append(
            Suspect(
                traceback=trace,
                memory_increase=growth,
                slope=slope,
                rvalue=rvalue,
                hist=sizes,
            )
        )

    return found


def _pprint_suspect(suspect):
    """Print a human-readable report for a single suspect to stdout.

    Args:
        suspect: Object exposing `traceback` (with a `format()` method that
            yields text lines), `memory_increase`, `slope` and `rvalue`.
    """
    headline = (
        "Most suspicious memory allocation in traceback "
        "(only printing out this one, but all (less suspicious)"
        " suspects will be investigated as well):"
    )
    print(headline)
    # One formatted traceback line per output line.
    print(*suspect.traceback.format(), sep="\n")
    print(f"Increase total={suspect.memory_increase}B")
    print(f"Slope={suspect.slope} B/detection")
    print(f"Rval={suspect.rvalue}")