|
1 import unittest, os |
|
2 from test import test_support |
|
3 |
|
4 import warnings |
|
# Silence DeprecationWarnings whose message matches the pattern below
# (presumably the interpreter's notice that divmod(), // and % on complex
# numbers are deprecated -- confirm against the interpreter version), because
# test_floordiv, test_mod and test_divmod deliberately exercise those
# operations on complex values.
warnings.filterwarnings(action="ignore",
                        message=".*complex divmod.*are deprecated",
                        category=DeprecationWarning)
|
10 |
|
11 from random import random |
|
12 |
|
13 # These tests ensure that complex math does the right thing |
|
14 |
|
15 class ComplexTest(unittest.TestCase): |
|
16 |
|
17 def assertAlmostEqual(self, a, b): |
|
18 if isinstance(a, complex): |
|
19 if isinstance(b, complex): |
|
20 unittest.TestCase.assertAlmostEqual(self, a.real, b.real) |
|
21 unittest.TestCase.assertAlmostEqual(self, a.imag, b.imag) |
|
22 else: |
|
23 unittest.TestCase.assertAlmostEqual(self, a.real, b) |
|
24 unittest.TestCase.assertAlmostEqual(self, a.imag, 0.) |
|
25 else: |
|
26 if isinstance(b, complex): |
|
27 unittest.TestCase.assertAlmostEqual(self, a, b.real) |
|
28 unittest.TestCase.assertAlmostEqual(self, 0., b.imag) |
|
29 else: |
|
30 unittest.TestCase.assertAlmostEqual(self, a, b) |
|
31 |
|
32 def assertCloseAbs(self, x, y, eps=1e-9): |
|
33 """Return true iff floats x and y "are close\"""" |
|
34 # put the one with larger magnitude second |
|
35 if abs(x) > abs(y): |
|
36 x, y = y, x |
|
37 if y == 0: |
|
38 return abs(x) < eps |
|
39 if x == 0: |
|
40 return abs(y) < eps |
|
41 # check that relative difference < eps |
|
42 self.assert_(abs((x-y)/y) < eps) |
|
43 |
|
44 def assertClose(self, x, y, eps=1e-9): |
|
45 """Return true iff complexes x and y "are close\"""" |
|
46 self.assertCloseAbs(x.real, y.real, eps) |
|
47 self.assertCloseAbs(x.imag, y.imag, eps) |
|
48 |
|
49 def assertIs(self, a, b): |
|
50 self.assert_(a is b) |
|
51 |
|
52 def check_div(self, x, y): |
|
53 """Compute complex z=x*y, and check that z/x==y and z/y==x.""" |
|
54 z = x * y |
|
55 if x != 0: |
|
56 q = z / x |
|
57 self.assertClose(q, y) |
|
58 q = z.__div__(x) |
|
59 self.assertClose(q, y) |
|
60 q = z.__truediv__(x) |
|
61 self.assertClose(q, y) |
|
62 if y != 0: |
|
63 q = z / y |
|
64 self.assertClose(q, x) |
|
65 q = z.__div__(y) |
|
66 self.assertClose(q, x) |
|
67 q = z.__truediv__(y) |
|
68 self.assertClose(q, x) |
|
69 |
|
70 def test_div(self): |
|
71 simple_real = [float(i) for i in xrange(-5, 6)] |
|
72 simple_complex = [complex(x, y) for x in simple_real for y in simple_real] |
|
73 for x in simple_complex: |
|
74 for y in simple_complex: |
|
75 self.check_div(x, y) |
|
76 |
|
77 # A naive complex division algorithm (such as in 2.0) is very prone to |
|
78 # nonsense errors for these (overflows and underflows). |
|
79 self.check_div(complex(1e200, 1e200), 1+0j) |
|
80 self.check_div(complex(1e-200, 1e-200), 1+0j) |
|
81 |
|
82 # Just for fun. |
|
83 for i in xrange(100): |
|
84 self.check_div(complex(random(), random()), |
|
85 complex(random(), random())) |
|
86 |
|
87 self.assertRaises(ZeroDivisionError, complex.__div__, 1+1j, 0+0j) |
|
88 # FIXME: The following currently crashes on Alpha |
|
89 # self.assertRaises(OverflowError, pow, 1e200+1j, 1e200+1j) |
|
90 |
|
91 def test_truediv(self): |
|
92 self.assertAlmostEqual(complex.__truediv__(2+0j, 1+1j), 1-1j) |
|
93 self.assertRaises(ZeroDivisionError, complex.__truediv__, 1+1j, 0+0j) |
|
94 |
|
95 def test_floordiv(self): |
|
96 self.assertAlmostEqual(complex.__floordiv__(3+0j, 1.5+0j), 2) |
|
97 self.assertRaises(ZeroDivisionError, complex.__floordiv__, 3+0j, 0+0j) |
|
98 |
|
99 def test_coerce(self): |
|
100 self.assertRaises(OverflowError, complex.__coerce__, 1+1j, 1L<<10000) |
|
101 |
|
102 def test_richcompare(self): |
|
103 self.assertRaises(OverflowError, complex.__eq__, 1+1j, 1L<<10000) |
|
104 self.assertEqual(complex.__lt__(1+1j, None), NotImplemented) |
|
105 self.assertIs(complex.__eq__(1+1j, 1+1j), True) |
|
106 self.assertIs(complex.__eq__(1+1j, 2+2j), False) |
|
107 self.assertIs(complex.__ne__(1+1j, 1+1j), False) |
|
108 self.assertIs(complex.__ne__(1+1j, 2+2j), True) |
|
109 self.assertRaises(TypeError, complex.__lt__, 1+1j, 2+2j) |
|
110 self.assertRaises(TypeError, complex.__le__, 1+1j, 2+2j) |
|
111 self.assertRaises(TypeError, complex.__gt__, 1+1j, 2+2j) |
|
112 self.assertRaises(TypeError, complex.__ge__, 1+1j, 2+2j) |
|
113 |
|
114 def test_mod(self): |
|
115 self.assertRaises(ZeroDivisionError, (1+1j).__mod__, 0+0j) |
|
116 |
|
117 a = 3.33+4.43j |
|
118 try: |
|
119 a % 0 |
|
120 except ZeroDivisionError: |
|
121 pass |
|
122 else: |
|
123 self.fail("modulo parama can't be 0") |
|
124 |
|
125 def test_divmod(self): |
|
126 self.assertRaises(ZeroDivisionError, divmod, 1+1j, 0+0j) |
|
127 |
|
128 def test_pow(self): |
|
129 self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) |
|
130 self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) |
|
131 self.assertRaises(ZeroDivisionError, pow, 0+0j, 1j) |
|
132 self.assertAlmostEqual(pow(1j, -1), 1/1j) |
|
133 self.assertAlmostEqual(pow(1j, 200), 1) |
|
134 self.assertRaises(ValueError, pow, 1+1j, 1+1j, 1+1j) |
|
135 |
|
136 a = 3.33+4.43j |
|
137 self.assertEqual(a ** 0j, 1) |
|
138 self.assertEqual(a ** 0.+0.j, 1) |
|
139 |
|
140 self.assertEqual(3j ** 0j, 1) |
|
141 self.assertEqual(3j ** 0, 1) |
|
142 |
|
143 try: |
|
144 0j ** a |
|
145 except ZeroDivisionError: |
|
146 pass |
|
147 else: |
|
148 self.fail("should fail 0.0 to negative or complex power") |
|
149 |
|
150 try: |
|
151 0j ** (3-2j) |
|
152 except ZeroDivisionError: |
|
153 pass |
|
154 else: |
|
155 self.fail("should fail 0.0 to negative or complex power") |
|
156 |
|
157 # The following is used to exercise certain code paths |
|
158 self.assertEqual(a ** 105, a ** 105) |
|
159 self.assertEqual(a ** -105, a ** -105) |
|
160 self.assertEqual(a ** -30, a ** -30) |
|
161 |
|
162 self.assertEqual(0.0j ** 0, 1) |
|
163 |
|
164 b = 5.1+2.3j |
|
165 self.assertRaises(ValueError, pow, a, b, 0) |
|
166 |
|
167 def test_boolcontext(self): |
|
168 for i in xrange(100): |
|
169 self.assert_(complex(random() + 1e-6, random() + 1e-6)) |
|
170 self.assert_(not complex(0.0, 0.0)) |
|
171 |
|
172 def test_conjugate(self): |
|
173 self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) |
|
174 |
|
175 def test_constructor(self): |
|
176 class OS: |
|
177 def __init__(self, value): self.value = value |
|
178 def __complex__(self): return self.value |
|
179 class NS(object): |
|
180 def __init__(self, value): self.value = value |
|
181 def __complex__(self): return self.value |
|
182 self.assertEqual(complex(OS(1+10j)), 1+10j) |
|
183 self.assertEqual(complex(NS(1+10j)), 1+10j) |
|
184 self.assertRaises(TypeError, complex, OS(None)) |
|
185 self.assertRaises(TypeError, complex, NS(None)) |
|
186 |
|
187 self.assertAlmostEqual(complex("1+10j"), 1+10j) |
|
188 self.assertAlmostEqual(complex(10), 10+0j) |
|
189 self.assertAlmostEqual(complex(10.0), 10+0j) |
|
190 self.assertAlmostEqual(complex(10L), 10+0j) |
|
191 self.assertAlmostEqual(complex(10+0j), 10+0j) |
|
192 self.assertAlmostEqual(complex(1,10), 1+10j) |
|
193 self.assertAlmostEqual(complex(1,10L), 1+10j) |
|
194 self.assertAlmostEqual(complex(1,10.0), 1+10j) |
|
195 self.assertAlmostEqual(complex(1L,10), 1+10j) |
|
196 self.assertAlmostEqual(complex(1L,10L), 1+10j) |
|
197 self.assertAlmostEqual(complex(1L,10.0), 1+10j) |
|
198 self.assertAlmostEqual(complex(1.0,10), 1+10j) |
|
199 self.assertAlmostEqual(complex(1.0,10L), 1+10j) |
|
200 self.assertAlmostEqual(complex(1.0,10.0), 1+10j) |
|
201 self.assertAlmostEqual(complex(3.14+0j), 3.14+0j) |
|
202 self.assertAlmostEqual(complex(3.14), 3.14+0j) |
|
203 self.assertAlmostEqual(complex(314), 314.0+0j) |
|
204 self.assertAlmostEqual(complex(314L), 314.0+0j) |
|
205 self.assertAlmostEqual(complex(3.14+0j, 0j), 3.14+0j) |
|
206 self.assertAlmostEqual(complex(3.14, 0.0), 3.14+0j) |
|
207 self.assertAlmostEqual(complex(314, 0), 314.0+0j) |
|
208 self.assertAlmostEqual(complex(314L, 0L), 314.0+0j) |
|
209 self.assertAlmostEqual(complex(0j, 3.14j), -3.14+0j) |
|
210 self.assertAlmostEqual(complex(0.0, 3.14j), -3.14+0j) |
|
211 self.assertAlmostEqual(complex(0j, 3.14), 3.14j) |
|
212 self.assertAlmostEqual(complex(0.0, 3.14), 3.14j) |
|
213 self.assertAlmostEqual(complex("1"), 1+0j) |
|
214 self.assertAlmostEqual(complex("1j"), 1j) |
|
215 self.assertAlmostEqual(complex(), 0) |
|
216 self.assertAlmostEqual(complex("-1"), -1) |
|
217 self.assertAlmostEqual(complex("+1"), +1) |
|
218 |
|
219 class complex2(complex): pass |
|
220 self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j) |
|
221 self.assertAlmostEqual(complex(real=17, imag=23), 17+23j) |
|
222 self.assertAlmostEqual(complex(real=17+23j), 17+23j) |
|
223 self.assertAlmostEqual(complex(real=17+23j, imag=23), 17+46j) |
|
224 self.assertAlmostEqual(complex(real=1+2j, imag=3+4j), -3+5j) |
|
225 |
|
226 c = 3.14 + 1j |
|
227 self.assert_(complex(c) is c) |
|
228 del c |
|
229 |
|
230 self.assertRaises(TypeError, complex, "1", "1") |
|
231 self.assertRaises(TypeError, complex, 1, "1") |
|
232 |
|
233 self.assertEqual(complex(" 3.14+J "), 3.14+1j) |
|
234 if test_support.have_unicode: |
|
235 self.assertEqual(complex(unicode(" 3.14+J ")), 3.14+1j) |
|
236 |
|
237 # SF bug 543840: complex(string) accepts strings with \0 |
|
238 # Fixed in 2.3. |
|
239 self.assertRaises(ValueError, complex, '1+1j\0j') |
|
240 |
|
241 self.assertRaises(TypeError, int, 5+3j) |
|
242 self.assertRaises(TypeError, long, 5+3j) |
|
243 self.assertRaises(TypeError, float, 5+3j) |
|
244 self.assertRaises(ValueError, complex, "") |
|
245 self.assertRaises(TypeError, complex, None) |
|
246 self.assertRaises(ValueError, complex, "\0") |
|
247 self.assertRaises(TypeError, complex, "1", "2") |
|
248 self.assertRaises(TypeError, complex, "1", 42) |
|
249 self.assertRaises(TypeError, complex, 1, "2") |
|
250 self.assertRaises(ValueError, complex, "1+") |
|
251 self.assertRaises(ValueError, complex, "1+1j+1j") |
|
252 self.assertRaises(ValueError, complex, "--") |
|
253 if test_support.have_unicode: |
|
254 self.assertRaises(ValueError, complex, unicode("1"*500)) |
|
255 self.assertRaises(ValueError, complex, unicode("x")) |
|
256 |
|
257 class EvilExc(Exception): |
|
258 pass |
|
259 |
|
260 class evilcomplex: |
|
261 def __complex__(self): |
|
262 raise EvilExc |
|
263 |
|
264 self.assertRaises(EvilExc, complex, evilcomplex()) |
|
265 |
|
266 class float2: |
|
267 def __init__(self, value): |
|
268 self.value = value |
|
269 def __float__(self): |
|
270 return self.value |
|
271 |
|
272 self.assertAlmostEqual(complex(float2(42.)), 42) |
|
273 self.assertAlmostEqual(complex(real=float2(17.), imag=float2(23.)), 17+23j) |
|
274 self.assertRaises(TypeError, complex, float2(None)) |
|
275 |
|
276 class complex0(complex): |
|
277 """Test usage of __complex__() when inheriting from 'complex'""" |
|
278 def __complex__(self): |
|
279 return 42j |
|
280 |
|
281 class complex1(complex): |
|
282 """Test usage of __complex__() with a __new__() method""" |
|
283 def __new__(self, value=0j): |
|
284 return complex.__new__(self, 2*value) |
|
285 def __complex__(self): |
|
286 return self |
|
287 |
|
288 class complex2(complex): |
|
289 """Make sure that __complex__() calls fail if anything other than a |
|
290 complex is returned""" |
|
291 def __complex__(self): |
|
292 return None |
|
293 |
|
294 self.assertAlmostEqual(complex(complex0(1j)), 42j) |
|
295 self.assertAlmostEqual(complex(complex1(1j)), 2j) |
|
296 self.assertRaises(TypeError, complex, complex2(1j)) |
|
297 |
|
298 def test_hash(self): |
|
299 for x in xrange(-30, 30): |
|
300 self.assertEqual(hash(x), hash(complex(x, 0))) |
|
301 x /= 3.0 # now check against floating point |
|
302 self.assertEqual(hash(x), hash(complex(x, 0.))) |
|
303 |
|
304 def test_abs(self): |
|
305 nums = [complex(x/3., y/7.) for x in xrange(-9,9) for y in xrange(-9,9)] |
|
306 for num in nums: |
|
307 self.assertAlmostEqual((num.real**2 + num.imag**2) ** 0.5, abs(num)) |
|
308 |
|
309 def test_repr(self): |
|
310 self.assertEqual(repr(1+6j), '(1+6j)') |
|
311 self.assertEqual(repr(1-6j), '(1-6j)') |
|
312 |
|
313 self.assertNotEqual(repr(-(1+0j)), '(-1+-0j)') |
|
314 |
|
315 def test_neg(self): |
|
316 self.assertEqual(-(1+6j), -1-6j) |
|
317 |
|
318 def test_file(self): |
|
319 a = 3.33+4.43j |
|
320 b = 5.1+2.3j |
|
321 |
|
322 fo = None |
|
323 try: |
|
324 fo = open(test_support.TESTFN, "wb") |
|
325 print >>fo, a, b |
|
326 fo.close() |
|
327 fo = open(test_support.TESTFN, "rb") |
|
328 self.assertEqual(fo.read(), "%s %s\n" % (a, b)) |
|
329 finally: |
|
330 if (fo is not None) and (not fo.closed): |
|
331 fo.close() |
|
332 try: |
|
333 os.remove(test_support.TESTFN) |
|
334 except (OSError, IOError): |
|
335 pass |
|
336 |
|
def test_main():
    """Entry point used by the regression-test driver: run ComplexTest."""
    test_classes = (ComplexTest,)
    test_support.run_unittest(*test_classes)


if __name__ == "__main__":
    # Only runs when this file is executed directly, not when imported.
    test_main()