154 lines
5.1 KiB
Python
154 lines
5.1 KiB
Python
![]() |
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Python memory leak detection utility.
|
||
|
|
||
|
Please don't use this class directly. Instead, use `MemoryChecker` wrapper.
|
||
|
"""
|
||
|
|
||
|
import collections
|
||
|
import copy
|
||
|
import gc
|
||
|
|
||
|
from tensorflow.python.framework import _python_memory_checker_helper
|
||
|
from tensorflow.python.platform import tf_logging as logging
|
||
|
from tensorflow.python.profiler import trace
|
||
|
|
||
|
|
||
|
def _get_typename(obj):
    """Build a readable, module-qualified type name for `obj`.

    Args:
      obj: Any Python object.

    Returns:
      `'<module>.<name>'` when the type of `obj` exposes a non-empty
      `__module__` attribute, otherwise just the bare type name.
    """
    cls = type(obj)
    module_name = getattr(cls, '__module__', None)
    # Guard clause: without a usable module there is nothing to prefix.
    if not module_name:
        return cls.__name__
    return f'{module_name}.{cls.__name__}'
|
||
|
|
||
|
|
||
|
def _create_python_object_snapshot():
    """Capture the ids of all objects known to the garbage collector.

    A full collection is run first (presumably so that unreachable garbage is
    not counted as live -- confirm against the `gc` documentation).

    Returns:
      A `collections.defaultdict(set)` mapping each type name, as produced by
      `_get_typename`, to the set of `id()` values of the objects of that type
      returned by `gc.get_objects()`.
    """
    gc.collect()
    # Fetch the tracked objects *before* allocating the result container so
    # that the container itself is not part of the snapshot it holds.
    tracked_objects = gc.get_objects()
    ids_by_typename = collections.defaultdict(set)
    for tracked in tracked_objects:
        typename = _get_typename(tracked)
        ids_by_typename[typename].add(id(tracked))
    return ids_by_typename
|
||
|
|
||
|
|
||
|
def _snapshot_diff(old_snapshot, new_snapshot, exclude_ids):
    """Count, per type name, how many more objects the new snapshot holds.

    Neither snapshot is modified.

    Args:
      old_snapshot: Mapping from type name to a set of object ids, such as the
        one returned by `_create_python_object_snapshot`.
      new_snapshot: Mapping of the same shape, taken later.
      exclude_ids: Set of object ids to ignore in both snapshots.

    Returns:
      A `collections.Counter` keyed by type name. Only strictly positive
      growth is reported: types whose object count stayed flat or shrank, and
      types that do not appear in `new_snapshot`, are omitted.
    """
    result = collections.Counter()
    # Use .get() instead of indexing. Snapshots are defaultdicts, so indexing
    # would insert a brand-new empty set into `old_snapshot` for every type
    # that first shows up in `new_snapshot` (mutating the caller's data), and
    # would raise KeyError if a plain dict were passed.
    no_ids = frozenset()
    for name, new_ids in new_snapshot.items():
        old_ids = old_snapshot.get(name, no_ids)
        growth = len(new_ids - exclude_ids) - len(old_ids - exclude_ids)
        # Store only positive growth; zero and negative entries are not
        # part of the result.
        if growth > 0:
            result[name] = growth
    return result
|
||
|
|
||
|
|
||
|
class _PythonMemoryChecker:
    """Python memory leak detection class.

    Records snapshots of the live Python objects (grouped by type name) and
    compares them to detect growth in the number of objects between
    snapshots.
    """

    def __init__(self):
        # Snapshots in recording order; each entry is the mapping returned by
        # `_create_python_object_snapshot`.
        self._snapshots = []

        # cache the function used by mark_stack_trace_and_call to avoid
        # contaminating the leak measurement.
        def _record_snapshot():
            self._snapshots.append(_create_python_object_snapshot())

        self._record_snapshot = _record_snapshot

    # We do not enable trace_wrapper on this function to avoid contaminating
    # the snapshot.
    def record_snapshot(self):
        """Take a snapshot of the current Python objects and store it."""
        # Function called using `mark_stack_trace_and_call` will have
        # "_python_memory_checker_helper" string in the C++ stack trace. This
        # will be used to filter out C++ memory allocations caused by this
        # function, because we are not interested in detecting memory growth
        # caused by memory checker itself.
        _python_memory_checker_helper.mark_stack_trace_and_call(
            self._record_snapshot)

    @trace.trace_wrapper
    def report(self):
        """Not implemented yet; currently does nothing."""
        # TODO(kkb): Implement.
        pass

    @trace.trace_wrapper
    def assert_no_leak_if_all_possibly_except_one(self):
        """Raises an exception if a leak is detected.

        This algorithm classifies a series of allocations as a leak if it's
        the same type at every snapshot, but possibly except one snapshot.

        Raises:
          AssertionError: If at least one type gained objects in every diff
            between consecutive snapshots, or in all of them but one. The
            message lists the offending type names in sorted order.
        """
        snapshot_diffs = [
            self._snapshot_diff(i, i + 1)
            for i in range(len(self._snapshots) - 1)
        ]

        # For each type name, the number of consecutive-snapshot diffs in
        # which that type gained objects.
        allocation_counter = collections.Counter()
        for diff in snapshot_diffs:
            allocation_counter.update(
                name for name, count in diff.items() if count > 0)

        # "Possibly except one": growth in all diffs but one still counts.
        min_growing_diffs = len(snapshot_diffs) - 1
        # Sorted so the error message is deterministic; iterating a set of
        # strings directly would give a run-dependent order.
        leaking_object_names = sorted(
            name for name, count in allocation_counter.items()
            if count >= min_growing_diffs)

        if leaking_object_names:
            object_list_to_print = '\n'.join(
                ' - ' + name for name in leaking_object_names)
            raise AssertionError(
                'These Python objects were allocated in every snapshot '
                f'possibly except one.\n\n{object_list_to_print}')

    @trace.trace_wrapper
    def assert_no_new_objects(self, threshold=None):
        """Assert no new Python objects.

        Compares the first recorded snapshot with the most recent one.

        Args:
          threshold: Optional mapping from type name (in the format produced
            by `_get_typename`) to the number of new objects of that type
            that is tolerated. Types that are not listed tolerate no new
            objects. `None` or an empty mapping means no tolerance at all.

        Raises:
          AssertionError: If, for any type, more new objects exist than its
            threshold allows. When no type exceeds its threshold but some
            type stays below it, a warning is logged instead.
        """
        if not threshold:
            threshold = {}

        count_diff = self._snapshot_diff(0, -1)
        # Keep the unadjusted counts for the messages below. A shallow copy
        # is enough: the counter only maps strings to ints.
        original_count_diff = count_diff.copy()
        count_diff.subtract(collections.Counter(threshold))

        # `default=0` covers the case where there is nothing to compare.
        if max(count_diff.values(), default=0) > 0:
            raise AssertionError(
                'New Python objects created exceeded the threshold.'
                '\nPython object threshold:\n'
                f'{threshold}\n\nNew Python objects:\n'
                f'{original_count_diff.most_common()}')
        elif min(count_diff.values(), default=0) < 0:
            logging.warning(
                'New Python objects created were less than the threshold.'
                '\nPython object threshold:\n'
                f'{threshold}\n\nNew Python objects:\n'
                f'{original_count_diff.most_common()}')

    @trace.trace_wrapper
    def _snapshot_diff(self, old_index, new_index):
        """Diff two recorded snapshots, ignoring the checker's own containers.

        Args:
          old_index: Index into the recorded snapshots of the older snapshot.
          new_index: Index into the recorded snapshots of the newer snapshot.

        Returns:
          The result of the module-level `_snapshot_diff` for the two
          snapshots, with the ids from `_get_internal_object_ids` excluded.
        """
        return _snapshot_diff(self._snapshots[old_index],
                              self._snapshots[new_index],
                              self._get_internal_object_ids())

    @trace.trace_wrapper
    def _get_internal_object_ids(self):
        """Return the ids of every snapshot mapping and of each of its values."""
        ids = set()
        for snapshot in self._snapshots:
            ids.add(id(snapshot))
            for id_set in snapshot.values():
                ids.add(id(id_set))
        return ids
|