|
1 """Helper to provide extensibility for pickle/cPickle. |
|
2 |
|
3 This is only useful to add pickle support for extension types defined in |
|
4 C, not for instances of user-defined classes. |
|
5 """ |
|
6 |
|
7 from types import ClassType as _ClassType |
|
8 |
|
# Names exported by "from copy_reg import *".
__all__ = [
    "pickle",
    "constructor",
    "add_extension",
    "remove_extension",
    "clear_extension_cache",
]
|
11 |
|
# Maps a type to the reduction function registered for it through pickle().
dispatch_table = {}
|
13 |
|
def pickle(ob_type, pickle_function, constructor_ob=None):
    """Register pickle_function as the reduction function for ob_type.

    The function is stored in dispatch_table under ob_type.

    Raises TypeError when ob_type is a classic class (its type is
    types.ClassType) or when pickle_function has no __call__ attribute.
    If constructor_ob is supplied it is only validated by constructor().
    """
    ob_type_is_classic = type(ob_type) is _ClassType
    if ob_type_is_classic:
        raise TypeError("copy_reg is not intended for use with classes")

    reducer_is_callable = hasattr(pickle_function, '__call__')
    if not reducer_is_callable:
        raise TypeError("reduction functions must be callable")

    dispatch_table[ob_type] = pickle_function

    # constructor_ob is a leftover from the old "safe for unpickling"
    # scheme; callers have no reason to pass it any more.
    if constructor_ob is None:
        return
    constructor(constructor_ob)
|
26 |
|
def constructor(object):
    """Check that the given constructor is callable.

    Returns None when the argument has a __call__ attribute and raises
    TypeError otherwise.  Nothing is recorded.
    """
    if hasattr(object, '__call__'):
        return
    raise TypeError("constructors must be callable")
|
30 |
|
# Example: register a reduction function for the built-in complex type.

try:
    complex
except NameError:
    # The name "complex" is not defined here, so there is nothing to register.
    pass
else:

    def pickle_complex(c):
        """Reduce the complex number c to (complex, (real, imag))."""
        parts = (c.real, c.imag)
        return complex, parts

    pickle(complex, pickle_complex, complex)
|
43 |
|
44 # Support for pickling new-style objects |
|
45 |
|
def _reconstructor(cls, base, state):
    """Build a bare instance of cls from its base type and base state.

    When base is object, a new cls instance is created with no state.
    Otherwise the instance is created through base.__new__(cls, state)
    and, if base defines its own __init__, initialised with
    base.__init__(obj, state).
    """
    if base is object:
        return object.__new__(cls)

    obj = base.__new__(cls, state)
    base_has_own_init = base.__init__ != object.__init__
    if base_has_own_init:
        base.__init__(obj, state)
    return obj
|
54 |
|
# Bit mask tested against a type's __flags__ in _reduce_ex().
# NOTE(review): presumably mirrors CPython's Py_TPFLAGS_HEAPTYPE -- confirm.
_HEAPTYPE = 1 << 9
|
56 |
|
57 # Python code for object.__reduce_ex__ for protocols 0 and 1 |
|
58 |
|
def _reduce_ex(self, proto):
    """Python implementation of object.__reduce_ex__ for protocols 0 and 1.

    Walks the MRO of self's class to find the first base whose __flags__
    does not have the _HEAPTYPE bit set, and uses that base to capture the
    base state of the instance.

    Returns (_reconstructor, args) or (_reconstructor, args, state), where
    args is (class, base, base_state) and state comes from __getstate__()
    when defined, otherwise from the instance __dict__.  The third item is
    omitted when that state is empty or missing.

    Raises TypeError when the selected base is the instance's own class,
    or when the instance has non-empty __slots__ but no __getstate__.
    """
    # Internal invariant: this helper only serves the old protocols.
    assert proto < 2
    for base in self.__class__.__mro__:
        if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
            break
    else:
        base = object  # not really reachable
    if base is object:
        state = None
    else:
        if base is self.__class__:
            raise TypeError("can't pickle %s objects" % base.__name__)
        # Copy the base-type portion of the instance (e.g. list(self)).
        state = base(self)
    args = (self.__class__, base, state)
    try:
        getstate = self.__getstate__
    except AttributeError:
        if getattr(self, "__slots__", None):
            raise TypeError("a class that defines __slots__ without "
                            "defining __getstate__ cannot be pickled")
        try:
            instance_state = self.__dict__
        except AttributeError:
            instance_state = None
    else:
        instance_state = getstate()
    if instance_state:
        return _reconstructor, args, instance_state
    return _reconstructor, args
|
89 |
|
90 # Helper for __reduce_ex__ protocol 2 |
|
91 |
|
def __newobj__(cls, *args):
    """Create an instance via cls.__new__(cls, *args) without calling __init__."""
    allocate = cls.__new__
    return allocate(cls, *args)
|
94 |
|
def _slotnames(cls):
    """Return a list of slot names for a given class.

    This needs to find slots defined by the class and its bases, so we
    can't simply return the __slots__ attribute.  We must walk down
    the Method Resolution Order and concatenate the __slots__ of each
    class found there.  (This assumes classes don't modify their
    __slots__ attribute to misrepresent their slots after the class is
    defined.)

    Private names (leading double underscore, no trailing double
    underscore) are returned in their mangled form.  The "__dict__" and
    "__weakref__" entries are skipped.  The result is cached on the class
    as __slotnames__ when the class allows it.
    """
    # Get the value from a cache in the class if possible
    names = cls.__dict__.get("__slotnames__")
    if names is not None:
        return names

    # basestring is only defined on Python 2; fall back to str where the
    # name does not exist so a single-string __slots__ is still recognised.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    # Not cached -- calculate the value
    names = []
    if hasattr(cls, "__slots__"):
        # Slots found -- gather slot names from all base classes
        for c in cls.__mro__:
            if "__slots__" not in c.__dict__:
                continue
            slots = c.__dict__['__slots__']
            # if class has a single slot, it can be given as a string
            if isinstance(slots, string_types):
                slots = (slots,)
            for name in slots:
                # special descriptors
                if name in ("__dict__", "__weakref__"):
                    continue
                if name.startswith('__') and not name.endswith('__'):
                    # Mangle the way the compiler does: leading underscores
                    # of the class name are dropped, and a class name made
                    # only of underscores causes no mangling at all.
                    stripped = c.__name__.lstrip('_')
                    if stripped:
                        names.append('_%s%s' % (stripped, name))
                    else:
                        names.append(name)
                else:
                    names.append(name)

    # Cache the outcome in the class if at all possible
    try:
        cls.__slotnames__ = names
    except Exception:
        # Best effort only, but let KeyboardInterrupt/SystemExit through.
        pass

    return names
|
141 |
|
142 # A registry of extension codes. This is an ad-hoc compression |
|
143 # mechanism. Whenever a global reference to <module>, <name> is about |
|
# to be pickled, the (<module>, <name>) tuple is looked up here to see
# if a registered extension code exists for it.  Extension codes are
|
146 # universal, so that the meaning of a pickle does not depend on |
|
147 # context. (There are also some codes reserved for local use that |
|
148 # don't have this restriction.) Codes are positive ints; 0 is |
|
149 # reserved. |
|
150 |
|
# key -> code
_extension_registry = {}
# code -> key
_inverted_registry = {}
# code -> object
_extension_cache = {}
# Don't ever rebind those names:  cPickle grabs a reference to them when
# it's initialized, and won't see a rebinding.
|
156 |
|
def add_extension(module, name, code):
    """Register an extension code.

    Associates the key (module, name) with the integer code in both
    _extension_registry and _inverted_registry.  Registering the very
    same pair again is a no-op.

    Raises ValueError when code is outside 1..0x7fffffff, when the key is
    already registered with another code, or when the code is already in
    use for another key.
    """
    code = int(code)
    if not 1 <= code <= 0x7fffffff:
        raise ValueError("code out of range")
    key = (module, name)
    if (_extension_registry.get(key) == code and
        _inverted_registry.get(code) == key):
        return  # Redundant registrations are benign
    if key in _extension_registry:
        raise ValueError("key %s is already registered with code %s" %
                         (key, _extension_registry[key]))
    if code in _inverted_registry:
        raise ValueError("code %s is already in use for key %s" %
                         (code, _inverted_registry[code]))
    _extension_registry[key] = code
    _inverted_registry[code] = key
|
174 |
|
def remove_extension(module, name, code):
    """Unregister an extension code.  For testing only.

    Raises ValueError unless (module, name) is currently registered with
    exactly this code.  Any cached object for the code is dropped too.
    """
    key = (module, name)
    code_mismatch = _extension_registry.get(key) != code
    if code_mismatch or _inverted_registry.get(code) != key:
        raise ValueError("key %s is not registered with code %s" %
                         (key, code))
    del _extension_registry[key]
    del _inverted_registry[code]
    # Drop the cache entry, if one exists.
    _extension_cache.pop(code, None)
|
186 |
|
def clear_extension_cache():
    """Empty the code -> object cache in place (the dict is not rebound)."""
    cache = _extension_cache
    cache.clear()
|
189 |
|
190 # Standard extension code assignments |
|
191 |
|
192 # Reserved ranges |
|
193 |
|
194 # First Last Count Purpose |
|
195 # 1 127 127 Reserved for Python standard library |
|
196 # 128 191 64 Reserved for Zope |
|
197 # 192 239 48 Reserved for 3rd parties |
|
198 # 240 255 16 Reserved for private use (will never be assigned) |
|
199 # 256 Inf Inf Reserved for future assignment |
|
200 |
|
201 # Extension codes are assigned by the Python Software Foundation. |