import inspect from functools import partial from joblib.externals.cloudpickle import dumps, loads WRAP_CACHE = {} class CloudpickledObjectWrapper: def __init__(self, obj, keep_wrapper=False): self._obj = obj self._keep_wrapper = keep_wrapper def __reduce__(self): _pickled_object = dumps(self._obj) if not self._keep_wrapper: return loads, (_pickled_object,) return _reconstruct_wrapper, (_pickled_object, self._keep_wrapper) def __getattr__(self, attr): # Ensure that the wrapped object can be used seemlessly as the # previous object. if attr not in ["_obj", "_keep_wrapper"]: return getattr(self._obj, attr) return getattr(self, attr) # Make sure the wrapped object conserves the callable property class CallableObjectWrapper(CloudpickledObjectWrapper): def __call__(self, *args, **kwargs): return self._obj(*args, **kwargs) def _wrap_non_picklable_objects(obj, keep_wrapper): if callable(obj): return CallableObjectWrapper(obj, keep_wrapper=keep_wrapper) return CloudpickledObjectWrapper(obj, keep_wrapper=keep_wrapper) def _reconstruct_wrapper(_pickled_object, keep_wrapper): obj = loads(_pickled_object) return _wrap_non_picklable_objects(obj, keep_wrapper) def _wrap_objects_when_needed(obj): # Function to introspect an object and decide if it should be wrapped or # not. need_wrap = "__main__" in getattr(obj, "__module__", "") if isinstance(obj, partial): return partial( _wrap_objects_when_needed(obj.func), *[_wrap_objects_when_needed(a) for a in obj.args], **{ k: _wrap_objects_when_needed(v) for k, v in obj.keywords.items() } ) if callable(obj): # Need wrap if the object is a function defined in a local scope of # another function. func_code = getattr(obj, "__code__", "") need_wrap |= getattr(func_code, "co_flags", 0) & inspect.CO_NESTED # Need wrap if the obj is a lambda expression func_name = getattr(obj, "__name__", "") need_wrap |= "" in func_name if not need_wrap: return obj wrapped_obj = WRAP_CACHE.get(obj) if wrapped_obj is None: wrapped_obj = _wrap_non_picklable_objects(obj, keep_wrapper=False) WRAP_CACHE[obj] = wrapped_obj return wrapped_obj def wrap_non_picklable_objects(obj, keep_wrapper=True): """Wrapper for non-picklable object to use cloudpickle to serialize them. Note that this wrapper tends to slow down the serialization process as it is done with cloudpickle which is typically slower compared to pickle. The proper way to solve serialization issues is to avoid defining functions and objects in the main scripts and to implement __reduce__ functions for complex classes. """ # If obj is a class, create a CloudpickledClassWrapper which instantiates # the object internally and wrap it directly in a CloudpickledObjectWrapper if inspect.isclass(obj): class CloudpickledClassWrapper(CloudpickledObjectWrapper): def __init__(self, *args, **kwargs): self._obj = obj(*args, **kwargs) self._keep_wrapper = keep_wrapper CloudpickledClassWrapper.__name__ = obj.__name__ return CloudpickledClassWrapper # If obj is an instance of a class, just wrap it in a regular # CloudpickledObjectWrapper return _wrap_non_picklable_objects(obj, keep_wrapper=keep_wrapper)