|
1 # Wrapper module for _ssl, providing some additional facilities |
|
2 # implemented in Python. Written by Bill Janssen. |
|
3 |
|
4 """\ |
|
5 This module provides some more Pythonic support for SSL. |
|
6 |
|
7 Object types: |
|
8 |
|
9 SSLSocket -- subtype of socket.socket which does SSL over the socket |
|
10 |
|
11 Exceptions: |
|
12 |
|
13 SSLError -- exception raised for I/O errors |
|
14 |
|
15 Functions: |
|
16 |
|
17 cert_time_to_seconds -- convert time string used for certificate |
|
18 notBefore and notAfter functions to integer |
|
19 seconds past the Epoch (the time values |
|
20 returned from time.time()) |
|
21 |
|
22 get_server_certificate (ADDR) -- fetch the certificate provided |
|
23 by the server at ADDR, a (HOST, PORT) pair. The |
|
24 certificate is validated only if CA_CERTS is supplied. |
|
25 |
|
26 Integer constants: |
|
27 |
|
28 SSL_ERROR_ZERO_RETURN |
|
29 SSL_ERROR_WANT_READ |
|
30 SSL_ERROR_WANT_WRITE |
|
31 SSL_ERROR_WANT_X509_LOOKUP |
|
32 SSL_ERROR_SYSCALL |
|
33 SSL_ERROR_SSL |
|
34 SSL_ERROR_WANT_CONNECT |
|
35 |
|
36 SSL_ERROR_EOF |
|
37 SSL_ERROR_INVALID_ERROR_CODE |
|
38 |
|
39 The following constants define the certificate requirements that one side |
|
40 allows or requires from the other side: |
|
41 |
|
42 CERT_NONE - no certificates from the other side are required (or will |
|
43 be looked at if provided) |
|
44 CERT_OPTIONAL - certificates are not required, but if provided will be |
|
45 validated, and if validation fails, the connection will |
|
46 also fail |
|
47 CERT_REQUIRED - certificates are required, and will be validated, and |
|
48 if validation fails, the connection will also fail |
|
49 |
|
50 The following constants identify various SSL protocol variants: |
|
51 |
|
52 PROTOCOL_SSLv2 |
|
53 PROTOCOL_SSLv3 |
|
54 PROTOCOL_SSLv23 |
|
55 PROTOCOL_TLSv1 |
|
56 """ |
|
57 |
|
58 import textwrap |
|
59 |
|
60 import _ssl # if we can't import it, let the error propagate |
|
61 |
|
62 from _ssl import SSLError |
|
63 from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED |
|
64 from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 |
|
65 from _ssl import RAND_status, RAND_egd, RAND_add |
|
66 from _ssl import \ |
|
67 SSL_ERROR_ZERO_RETURN, \ |
|
68 SSL_ERROR_WANT_READ, \ |
|
69 SSL_ERROR_WANT_WRITE, \ |
|
70 SSL_ERROR_WANT_X509_LOOKUP, \ |
|
71 SSL_ERROR_SYSCALL, \ |
|
72 SSL_ERROR_SSL, \ |
|
73 SSL_ERROR_WANT_CONNECT, \ |
|
74 SSL_ERROR_EOF, \ |
|
75 SSL_ERROR_INVALID_ERROR_CODE |
|
76 |
|
77 from socket import socket, _fileobject |
|
78 from socket import getnameinfo as _getnameinfo |
|
79 import base64 # for DER-to-PEM translation |
|
80 |
|
class SSLSocket(socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel.

    While no SSL object is attached (self._sslobj is None) the data
    methods fall through to the plain socket implementation; once one
    is attached they go through the SSL object instead."""

    def __init__(self, sock, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 suppress_ragged_eofs=True):

        """Wrap the existing socket SOCK.

        KEYFILE, CERTFILE, CERT_REQS, SSL_VERSION and CA_CERTS are
        passed through to _ssl.sslwrap(); KEYFILE defaults to CERTFILE
        when only CERTFILE is given.  SERVER_SIDE selects the
        server-side role.  If SOCK is already connected the SSL object
        is created immediately and, when DO_HANDSHAKE_ON_CONNECT is
        true, the handshake is performed right away.  When
        SUPPRESS_RAGGED_EOFS is true, read() returns '' on
        SSL_ERROR_EOF instead of raising."""

        # Local import: only this probe needs the exception type.
        from socket import error as socket_error

        socket.__init__(self, _sock=sock._sock)
        # the initializer for socket trashes the methods (tsk, tsk), so...
        self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
        self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
        self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
        self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
        self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
        self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)

        if certfile and not keyfile:
            keyfile = certfile

        # Record the configuration before anything that can fail, so
        # the instance is fully populated (close() reads
        # _makefile_refs) even if the handshake below raises.
        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self._makefile_refs = 0

        # see if it's connected
        try:
            socket.getpeername(self)
        except socket_error:
            # no, no connection yet
            self._sslobj = None
        else:
            # yes, create the SSL object
            self._sslobj = _ssl.sslwrap(self._sock, server_side,
                                        keyfile, certfile,
                                        cert_reqs, ssl_version, ca_certs)
            if do_handshake_on_connect:
                # Run the handshake with no timeout, then restore the
                # caller's timeout whether or not it succeeded.
                timeout = self.gettimeout()
                try:
                    self.settimeout(None)
                    self.do_handshake()
                finally:
                    self.settimeout(timeout)

    def read(self, len=1024):

        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""

        try:
            return self._sslobj.read(len)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                return ''
            raise

    def write(self, data):

        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):

        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):

        """Return the SSL object's cipher() result, or None when no
        SSL object is attached."""

        if not self._sslobj:
            return None
        return self._sslobj.cipher()

    def send(self, data, flags=0):

        """Send DATA and return the number of bytes written.

        Over SSL, FLAGS must be 0 (ValueError otherwise), and 0 is
        returned when the SSL layer reports SSL_ERROR_WANT_READ or
        SSL_ERROR_WANT_WRITE.  Without an SSL object this delegates to
        socket.send()."""

        if not self._sslobj:
            return socket.send(self, data, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to send() on %s" %
                self.__class__)
        try:
            return self._sslobj.write(data)
        except SSLError as x:
            if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                # nothing was written this time; let the caller retry
                return 0
            raise

    def sendto(self, data, addr, flags=0):

        """Delegate to socket.sendto(); raises ValueError once an SSL
        object is attached."""

        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        return socket.sendto(self, data, addr, flags)

    def sendall(self, data, flags=0):

        """Send all of DATA.

        Over SSL, FLAGS must be 0 (ValueError otherwise); send() is
        called repeatedly until everything is written and len(DATA) is
        returned.  Without an SSL object this delegates to
        socket.sendall()."""

        if not self._sslobj:
            return socket.sendall(self, data, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to sendall() on %s" %
                self.__class__)
        amount = len(data)
        count = 0
        while count < amount:
            count += self.send(data[count:])
        return amount

    def recv(self, buflen=1024, flags=0):

        """Receive up to BUFLEN bytes.

        Over SSL, FLAGS must be 0 (ValueError otherwise) and the read
        is retried for as long as the SSL layer reports
        SSL_ERROR_WANT_READ.  Without an SSL object this delegates to
        socket.recv()."""

        if not self._sslobj:
            return socket.recv(self, buflen, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to recv() on %s" %
                self.__class__)
        while True:
            try:
                return self.read(buflen)
            except SSLError as x:
                if x.args[0] != SSL_ERROR_WANT_READ:
                    raise

    def recv_into(self, buffer, nbytes=None, flags=0):

        """Receive into BUFFER and return the number of bytes stored.

        NBYTES defaults to len(BUFFER) when BUFFER is non-empty and to
        1024 otherwise.  Over SSL, FLAGS must be 0 (ValueError
        otherwise) and the read is retried for as long as the SSL
        layer reports SSL_ERROR_WANT_READ.  Without an SSL object this
        delegates to socket.recv_into()."""

        if nbytes is None:
            if buffer:
                nbytes = len(buffer)
            else:
                nbytes = 1024
        if not self._sslobj:
            return socket.recv_into(self, buffer, nbytes, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to recv_into() on %s" %
                self.__class__)
        while True:
            try:
                tmp_buffer = self.read(nbytes)
                v = len(tmp_buffer)
                buffer[:v] = tmp_buffer
                return v
            except SSLError as x:
                if x.args[0] != SSL_ERROR_WANT_READ:
                    raise

    # NOTE(review): the leading `addr` parameter has no obvious
    # counterpart in socket.recvfrom(bufsize[, flags]); it is kept here
    # (and in the lambda installed by __init__) for interface
    # compatibility -- confirm and consider dropping it.
    def recvfrom(self, addr, buflen=1024, flags=0):

        """Delegate to socket.recvfrom(); raises ValueError once an SSL
        object is attached."""

        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom(self, addr, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):

        """Delegate to socket.recvfrom_into(); raises ValueError once an
        SSL object is attached."""

        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom_into(self, buffer, nbytes, flags)

    def pending(self):

        """Return the SSL object's pending() result, or 0 when no SSL
        object is attached."""

        if self._sslobj:
            return self._sslobj.pending()
        return 0

    def unwrap(self):

        """Shut down the SSL object, detach it, and return whatever its
        shutdown() returned.  Raises ValueError when no SSL object is
        attached."""

        if not self._sslobj:
            raise ValueError("No SSL wrapper around " + str(self))
        s = self._sslobj.shutdown()
        self._sslobj = None
        return s

    def shutdown(self, how):

        """Drop the SSL object and shut down the underlying socket."""

        self._sslobj = None
        socket.shutdown(self, how)

    def close(self):

        """Close the socket, unless file objects created by makefile()
        still hold references; in that case only release one
        reference."""

        if self._makefile_refs < 1:
            self._sslobj = None
            socket.close(self)
        else:
            self._makefile_refs -= 1

    def do_handshake(self):

        """Perform a TLS/SSL handshake."""

        self._sslobj.do_handshake()

    def connect(self, addr):

        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""

        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs)
        if self.do_handshake_on_connect:
            self.do_handshake()

    def accept(self):

        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        return (SSLSocket(newsock,
                          keyfile=self.keyfile,
                          certfile=self.certfile,
                          server_side=True,
                          cert_reqs=self.cert_reqs,
                          ssl_version=self.ssl_version,
                          ca_certs=self.ca_certs,
                          do_handshake_on_connect=self.do_handshake_on_connect,
                          suppress_ragged_eofs=self.suppress_ragged_eofs),
                addr)

    def makefile(self, mode='r', bufsize=-1):

        """Make and return a file-like object that
        works with the SSL connection.  Just use the code
        from the socket module."""

        # Each file object holds one reference; close() only really
        # closes the socket once those have been released.
        self._makefile_refs += 1
        return _fileobject(self, mode, bufsize)
|
337 |
|
338 |
|
339 |
|
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True):

    """Wrap SOCK in an SSLSocket and return it.  Every keyword
    argument is forwarded unchanged to the SSLSocket constructor."""

    options = dict(keyfile=keyfile,
                   certfile=certfile,
                   server_side=server_side,
                   cert_reqs=cert_reqs,
                   ssl_version=ssl_version,
                   ca_certs=ca_certs,
                   do_handshake_on_connect=do_handshake_on_connect,
                   suppress_ragged_eofs=suppress_ragged_eofs)
    return SSLSocket(sock, **options)
|
351 |
|
352 |
|
353 # some utility functions |
|
354 |
|
def cert_time_to_seconds(cert_time):

    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The format string only accepts the literal timezone "GMT", so the
    parsed fields are converted as UTC; the result does not depend on
    the local timezone of the machine.  Returned as a float.  Raises
    ValueError (from time.strptime) if CERT_TIME does not match."""

    import calendar
    import time
    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    # calendar.timegm() treats the struct_time as UTC; time.mktime()
    # would treat it as local time and skew the result by the local
    # UTC offset.
    return float(calendar.timegm(parsed))
|
363 |
|
# Delimiter lines that enclose the base64 body of a PEM certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"


def DER_cert_to_PEM_cert(der_cert_bytes):

    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The result is the header line, the base64 encoding of
    DER_CERT_BYTES wrapped at 64 characters per line, and the footer
    line, each terminated by a newline."""

    encoded = base64.standard_b64encode(der_cert_bytes)
    if not isinstance(encoded, str):
        # the encoder hands back bytes on interpreters where str is
        # text; the PEM form is plain ASCII either way
        encoded = encoded.decode('ascii')
    # The '\n' after the wrapped body matters: without it the footer
    # is glued onto the last base64 line.
    return (PEM_HEADER + '\n' +
            textwrap.fill(encoded, 64) + '\n' +
            PEM_FOOTER + '\n')
|
382 |
|
def PEM_cert_to_DER_cert(pem_cert_string):

    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Leading and trailing whitespace around the PEM text is ignored.
    Raises ValueError if the text does not start with PEM_HEADER or
    does not end with PEM_FOOTER."""

    # Strip once so the header check, the footer check and the slice
    # below all look at the same text.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodestring(d)
|
396 |
|
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):

    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    ADDR is a (host, port) pair.  The connection is always closed
    before returning, including when connecting or fetching the
    certificate fails."""

    # Unpacking is kept as an early shape check: anything that is not
    # a 2-item address is rejected before a socket is created.
    host, port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # do not leak the socket when connect() or the handshake raises
        s.close()
    return DER_cert_to_PEM_cert(dercert)
|
415 |
|
def get_protocol_name(protocol_code):

    """Return the printable name ("TLSv1", "SSLv23", "SSLv2" or
    "SSLv3") of the PROTOCOL_* constant PROTOCOL_CODE, or "<unknown>"
    if it matches none of them."""

    # Checked in this order; the first constant equal to the code wins.
    known_protocols = (
        (PROTOCOL_TLSv1, "TLSv1"),
        (PROTOCOL_SSLv23, "SSLv23"),
        (PROTOCOL_SSLv2, "SSLv2"),
        (PROTOCOL_SSLv3, "SSLv3"),
    )
    for code, label in known_protocols:
        if protocol_code == code:
            return label
    return "<unknown>"
|
427 |
|
428 |
|
429 # a replacement for the old socket.ssl function |
|
430 |
|
def sslwrap_simple(sock, keyfile=None, certfile=None):

    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0.

    Returns the raw SSL object created by _ssl.sslwrap() (client side,
    CERT_NONE, PROTOCOL_SSLv23).  The handshake is performed right
    away only if SOCK is already connected."""

    # Local import: only the connection probe below needs it.
    from socket import error as socket_error

    # _ssl.sslwrap() is given the low-level socket, so unwrap a
    # socket.socket-style object that carries one in `_sock`.
    if hasattr(sock, "_sock"):
        sock = sock._sock

    ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
                            PROTOCOL_SSLv23, None)
    try:
        sock.getpeername()
    except socket_error:
        # no, no connection yet
        pass
    else:
        # yes, do the handshake
        ssl_sock.do_handshake()

    return ssl_sock