python-2.5.2/win32/Lib/test/test_socket.py
changeset 0 ae805ac0140d
equal deleted inserted replaced
-1:000000000000 0:ae805ac0140d
       
     1 #!/usr/bin/env python
       
     2 
       
     3 import unittest
       
     4 from test import test_support
       
     5 
       
     6 import socket
       
     7 import select
       
     8 import time
       
     9 import thread, threading
       
    10 import Queue
       
    11 import sys
       
    12 import array
       
    13 from weakref import proxy
       
    14 import signal
       
    15 
       
     # Starting port for the test servers.  The fixtures below rebind this
     # global to whatever test_support.bind_port() returns.
     PORT = 50007
     # Host name the server sockets bind to and the clients connect to.
     HOST = 'localhost'
     # Payload exchanged between client and server in the send/recv tests.
     MSG = 'Michael Gilfix was here\n'
       
    19 
       
    20 class SocketTCPTest(unittest.TestCase):
       
    21 
       
    22     def setUp(self):
       
    23         self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
    24         self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
       
    25         global PORT
       
    26         PORT = test_support.bind_port(self.serv, HOST, PORT)
       
    27         self.serv.listen(1)
       
    28 
       
    29     def tearDown(self):
       
    30         self.serv.close()
       
    31         self.serv = None
       
    32 
       
    33 class SocketUDPTest(unittest.TestCase):
       
    34 
       
    35     def setUp(self):
       
    36         self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
       
    37         self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
       
    38         global PORT
       
    39         PORT = test_support.bind_port(self.serv, HOST, PORT)
       
    40 
       
    41     def tearDown(self):
       
    42         self.serv.close()
       
    43         self.serv = None
       
    44 
       
    45 class ThreadableTest:
       
    46     """Threadable Test class
       
    47 
       
    48     The ThreadableTest class makes it easy to create a threaded
       
    49     client/server pair from an existing unit test. To create a
       
    50     new threaded class from an existing unit test, use multiple
       
    51     inheritance:
       
    52 
       
    53         class NewClass (OldClass, ThreadableTest):
       
    54             pass
       
    55 
       
    56     This class defines two new fixture functions with obvious
       
    57     purposes for overriding:
       
    58 
       
    59         clientSetUp ()
       
    60         clientTearDown ()
       
    61 
       
    62     Any new test functions within the class must then define
       
    63     tests in pairs, where the test name is preceeded with a
       
    64     '_' to indicate the client portion of the test. Ex:
       
    65 
       
    66         def testFoo(self):
       
    67             # Server portion
       
    68 
       
    69         def _testFoo(self):
       
    70             # Client portion
       
    71 
       
    72     Any exceptions raised by the clients during their tests
       
    73     are caught and transferred to the main thread to alert
       
    74     the testing framework.
       
    75 
       
    76     Note, the server setup function cannot call any blocking
       
    77     functions that rely on the client thread during setup,
       
    78     unless serverExplicityReady() is called just before
       
    79     the blocking call (such as in setting up a client/server
       
    80     connection and performing the accept() in setUp().
       
    81     """
       
    82 
       
    83     def __init__(self):
       
    84         # Swap the true setup function
       
    85         self.__setUp = self.setUp
       
    86         self.__tearDown = self.tearDown
       
    87         self.setUp = self._setUp
       
    88         self.tearDown = self._tearDown
       
    89 
       
    90     def serverExplicitReady(self):
       
    91         """This method allows the server to explicitly indicate that
       
    92         it wants the client thread to proceed. This is useful if the
       
    93         server is about to execute a blocking routine that is
       
    94         dependent upon the client thread during its setup routine."""
       
    95         self.server_ready.set()
       
    96 
       
    97     def _setUp(self):
       
    98         self.server_ready = threading.Event()
       
    99         self.client_ready = threading.Event()
       
   100         self.done = threading.Event()
       
   101         self.queue = Queue.Queue(1)
       
   102 
       
   103         # Do some munging to start the client test.
       
   104         methodname = self.id()
       
   105         i = methodname.rfind('.')
       
   106         methodname = methodname[i+1:]
       
   107         test_method = getattr(self, '_' + methodname)
       
   108         self.client_thread = thread.start_new_thread(
       
   109             self.clientRun, (test_method,))
       
   110 
       
   111         self.__setUp()
       
   112         if not self.server_ready.isSet():
       
   113             self.server_ready.set()
       
   114         self.client_ready.wait()
       
   115 
       
   116     def _tearDown(self):
       
   117         self.__tearDown()
       
   118         self.done.wait()
       
   119 
       
   120         if not self.queue.empty():
       
   121             msg = self.queue.get()
       
   122             self.fail(msg)
       
   123 
       
   124     def clientRun(self, test_func):
       
   125         self.server_ready.wait()
       
   126         self.client_ready.set()
       
   127         self.clientSetUp()
       
   128         if not callable(test_func):
       
   129             raise TypeError, "test_func must be a callable function"
       
   130         try:
       
   131             test_func()
       
   132         except Exception, strerror:
       
   133             self.queue.put(strerror)
       
   134         self.clientTearDown()
       
   135 
       
   136     def clientSetUp(self):
       
   137         raise NotImplementedError, "clientSetUp must be implemented."
       
   138 
       
   139     def clientTearDown(self):
       
   140         self.done.set()
       
   141         thread.exit()
       
   142 
       
   143 class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
       
   144 
       
   145     def __init__(self, methodName='runTest'):
       
   146         SocketTCPTest.__init__(self, methodName=methodName)
       
   147         ThreadableTest.__init__(self)
       
   148 
       
   149     def clientSetUp(self):
       
   150         self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   151 
       
   152     def clientTearDown(self):
       
   153         self.cli.close()
       
   154         self.cli = None
       
   155         ThreadableTest.clientTearDown(self)
       
   156 
       
   157 class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
       
   158 
       
   159     def __init__(self, methodName='runTest'):
       
   160         SocketUDPTest.__init__(self, methodName=methodName)
       
   161         ThreadableTest.__init__(self)
       
   162 
       
   163     def clientSetUp(self):
       
   164         self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
       
   165 
       
   166 class SocketConnectedTest(ThreadedTCPSocketTest):
       
   167 
       
   168     def __init__(self, methodName='runTest'):
       
   169         ThreadedTCPSocketTest.__init__(self, methodName=methodName)
       
   170 
       
   171     def setUp(self):
       
   172         ThreadedTCPSocketTest.setUp(self)
       
   173         # Indicate explicitly we're ready for the client thread to
       
   174         # proceed and then perform the blocking call to accept
       
   175         self.serverExplicitReady()
       
   176         conn, addr = self.serv.accept()
       
   177         self.cli_conn = conn
       
   178 
       
   179     def tearDown(self):
       
   180         self.cli_conn.close()
       
   181         self.cli_conn = None
       
   182         ThreadedTCPSocketTest.tearDown(self)
       
   183 
       
   184     def clientSetUp(self):
       
   185         ThreadedTCPSocketTest.clientSetUp(self)
       
   186         self.cli.connect((HOST, PORT))
       
   187         self.serv_conn = self.cli
       
   188 
       
   189     def clientTearDown(self):
       
   190         self.serv_conn.close()
       
   191         self.serv_conn = None
       
   192         ThreadedTCPSocketTest.clientTearDown(self)
       
   193 
       
   194 class SocketPairTest(unittest.TestCase, ThreadableTest):
       
   195 
       
   196     def __init__(self, methodName='runTest'):
       
   197         unittest.TestCase.__init__(self, methodName=methodName)
       
   198         ThreadableTest.__init__(self)
       
   199 
       
   200     def setUp(self):
       
   201         self.serv, self.cli = socket.socketpair()
       
   202 
       
   203     def tearDown(self):
       
   204         self.serv.close()
       
   205         self.serv = None
       
   206 
       
   207     def clientSetUp(self):
       
   208         pass
       
   209 
       
   210     def clientTearDown(self):
       
   211         self.cli.close()
       
   212         self.cli = None
       
   213         ThreadableTest.clientTearDown(self)
       
   214 
       
   215 
       
   216 #######################################################################
       
   217 ## Begin Tests
       
   218 
       
   219 class GeneralModuleTests(unittest.TestCase):
       
   220 
       
   221     def test_weakref(self):
       
   222         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   223         p = proxy(s)
       
   224         self.assertEqual(p.fileno(), s.fileno())
       
   225         s.close()
       
   226         s = None
       
   227         try:
       
   228             p.fileno()
       
   229         except ReferenceError:
       
   230             pass
       
   231         else:
       
   232             self.fail('Socket proxy still exists')
       
   233 
       
   234     def testSocketError(self):
       
   235         # Testing socket module exceptions
       
   236         def raise_error(*args, **kwargs):
       
   237             raise socket.error
       
   238         def raise_herror(*args, **kwargs):
       
   239             raise socket.herror
       
   240         def raise_gaierror(*args, **kwargs):
       
   241             raise socket.gaierror
       
   242         self.failUnlessRaises(socket.error, raise_error,
       
   243                               "Error raising socket exception.")
       
   244         self.failUnlessRaises(socket.error, raise_herror,
       
   245                               "Error raising socket exception.")
       
   246         self.failUnlessRaises(socket.error, raise_gaierror,
       
   247                               "Error raising socket exception.")
       
   248 
       
   249     def testCrucialConstants(self):
       
   250         # Testing for mission critical constants
       
   251         socket.AF_INET
       
   252         socket.SOCK_STREAM
       
   253         socket.SOCK_DGRAM
       
   254         socket.SOCK_RAW
       
   255         socket.SOCK_RDM
       
   256         socket.SOCK_SEQPACKET
       
   257         socket.SOL_SOCKET
       
   258         socket.SO_REUSEADDR
       
   259 
       
   260     def testHostnameRes(self):
       
   261         # Testing hostname resolution mechanisms
       
   262         hostname = socket.gethostname()
       
   263         try:
       
   264             ip = socket.gethostbyname(hostname)
       
   265         except socket.error:
       
   266             # Probably name lookup wasn't set up right; skip this test
       
   267             return
       
   268         self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
       
   269         try:
       
   270             hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
       
   271         except socket.error:
       
   272             # Probably a similar problem as above; skip this test
       
   273             return
       
   274         all_host_names = [hostname, hname] + aliases
       
   275         fqhn = socket.getfqdn(ip)
       
   276         if not fqhn in all_host_names:
       
   277             self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
       
   278 
       
   279     def testRefCountGetNameInfo(self):
       
   280         # Testing reference count for getnameinfo
       
   281         import sys
       
   282         if hasattr(sys, "getrefcount"):
       
   283             try:
       
   284                 # On some versions, this loses a reference
       
   285                 orig = sys.getrefcount(__name__)
       
   286                 socket.getnameinfo(__name__,0)
       
   287             except SystemError:
       
   288                 if sys.getrefcount(__name__) <> orig:
       
   289                     self.fail("socket.getnameinfo loses a reference")
       
   290 
       
   291     def testInterpreterCrash(self):
       
   292         # Making sure getnameinfo doesn't crash the interpreter
       
   293         try:
       
   294             # On some versions, this crashes the interpreter.
       
   295             socket.getnameinfo(('x', 0, 0, 0), 0)
       
   296         except socket.error:
       
   297             pass
       
   298 
       
   299     def testNtoH(self):
       
   300         # This just checks that htons etc. are their own inverse,
       
   301         # when looking at the lower 16 or 32 bits.
       
   302         sizes = {socket.htonl: 32, socket.ntohl: 32,
       
   303                  socket.htons: 16, socket.ntohs: 16}
       
   304         for func, size in sizes.items():
       
   305             mask = (1L<<size) - 1
       
   306             for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
       
   307                 self.assertEqual(i & mask, func(func(i&mask)) & mask)
       
   308 
       
   309             swapped = func(mask)
       
   310             self.assertEqual(swapped & mask, mask)
       
   311             self.assertRaises(OverflowError, func, 1L<<34)
       
   312 
       
   313     def testGetServBy(self):
       
   314         eq = self.assertEqual
       
   315         # Find one service that exists, then check all the related interfaces.
       
   316         # I've ordered this by protocols that have both a tcp and udp
       
   317         # protocol, at least for modern Linuxes.
       
   318         if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6',
       
   319                             'freebsd7', 'darwin'):
       
   320             # avoid the 'echo' service on this platform, as there is an
       
   321             # assumption breaking non-standard port/protocol entry
       
   322             services = ('daytime', 'qotd', 'domain')
       
   323         else:
       
   324             services = ('echo', 'daytime', 'domain')
       
   325         for service in services:
       
   326             try:
       
   327                 port = socket.getservbyname(service, 'tcp')
       
   328                 break
       
   329             except socket.error:
       
   330                 pass
       
   331         else:
       
   332             raise socket.error
       
   333         # Try same call with optional protocol omitted
       
   334         port2 = socket.getservbyname(service)
       
   335         eq(port, port2)
       
   336         # Try udp, but don't barf it it doesn't exist
       
   337         try:
       
   338             udpport = socket.getservbyname(service, 'udp')
       
   339         except socket.error:
       
   340             udpport = None
       
   341         else:
       
   342             eq(udpport, port)
       
   343         # Now make sure the lookup by port returns the same service name
       
   344         eq(socket.getservbyport(port2), service)
       
   345         eq(socket.getservbyport(port, 'tcp'), service)
       
   346         if udpport is not None:
       
   347             eq(socket.getservbyport(udpport, 'udp'), service)
       
   348 
       
   349     def testDefaultTimeout(self):
       
   350         # Testing default timeout
       
   351         # The default timeout should initially be None
       
   352         self.assertEqual(socket.getdefaulttimeout(), None)
       
   353         s = socket.socket()
       
   354         self.assertEqual(s.gettimeout(), None)
       
   355         s.close()
       
   356 
       
   357         # Set the default timeout to 10, and see if it propagates
       
   358         socket.setdefaulttimeout(10)
       
   359         self.assertEqual(socket.getdefaulttimeout(), 10)
       
   360         s = socket.socket()
       
   361         self.assertEqual(s.gettimeout(), 10)
       
   362         s.close()
       
   363 
       
   364         # Reset the default timeout to None, and see if it propagates
       
   365         socket.setdefaulttimeout(None)
       
   366         self.assertEqual(socket.getdefaulttimeout(), None)
       
   367         s = socket.socket()
       
   368         self.assertEqual(s.gettimeout(), None)
       
   369         s.close()
       
   370 
       
   371         # Check that setting it to an invalid value raises ValueError
       
   372         self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
       
   373 
       
   374         # Check that setting it to an invalid type raises TypeError
       
   375         self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
       
   376 
       
   377     def testIPv4toString(self):
       
   378         if not hasattr(socket, 'inet_pton'):
       
   379             return # No inet_pton() on this platform
       
   380         from socket import inet_aton as f, inet_pton, AF_INET
       
   381         g = lambda a: inet_pton(AF_INET, a)
       
   382 
       
   383         self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0'))
       
   384         self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0'))
       
   385         self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
       
   386         self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4'))
       
   387         self.assertEquals('\xff\xff\xff\xff', f('255.255.255.255'))
       
   388 
       
   389         self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0'))
       
   390         self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0'))
       
   391         self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
       
   392         self.assertEquals('\xff\xff\xff\xff', g('255.255.255.255'))
       
   393 
       
   394     def testIPv6toString(self):
       
   395         if not hasattr(socket, 'inet_pton'):
       
   396             return # No inet_pton() on this platform
       
   397         try:
       
   398             from socket import inet_pton, AF_INET6, has_ipv6
       
   399             if not has_ipv6:
       
   400                 return
       
   401         except ImportError:
       
   402             return
       
   403         f = lambda a: inet_pton(AF_INET6, a)
       
   404 
       
   405         self.assertEquals('\x00' * 16, f('::'))
       
   406         self.assertEquals('\x00' * 16, f('0::0'))
       
   407         self.assertEquals('\x00\x01' + '\x00' * 14, f('1::'))
       
   408         self.assertEquals(
       
   409             '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
       
   410             f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
       
   411         )
       
   412 
       
   413     def testStringToIPv4(self):
       
   414         if not hasattr(socket, 'inet_ntop'):
       
   415             return # No inet_ntop() on this platform
       
   416         from socket import inet_ntoa as f, inet_ntop, AF_INET
       
   417         g = lambda a: inet_ntop(AF_INET, a)
       
   418 
       
   419         self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00'))
       
   420         self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55'))
       
   421         self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff'))
       
   422         self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04'))
       
   423 
       
   424         self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00'))
       
   425         self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55'))
       
   426         self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff'))
       
   427 
       
   428     def testStringToIPv6(self):
       
   429         if not hasattr(socket, 'inet_ntop'):
       
   430             return # No inet_ntop() on this platform
       
   431         try:
       
   432             from socket import inet_ntop, AF_INET6, has_ipv6
       
   433             if not has_ipv6:
       
   434                 return
       
   435         except ImportError:
       
   436             return
       
   437         f = lambda a: inet_ntop(AF_INET6, a)
       
   438 
       
   439         self.assertEquals('::', f('\x00' * 16))
       
   440         self.assertEquals('::1', f('\x00' * 15 + '\x01'))
       
   441         self.assertEquals(
       
   442             'aef:b01:506:1001:ffff:9997:55:170',
       
   443             f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
       
   444         )
       
   445 
       
   446     # XXX The following don't test module-level functionality...
       
   447 
       
   448     def testSockName(self):
       
   449         # Testing getsockname()
       
   450         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   451         sock.bind(("0.0.0.0", PORT+1))
       
   452         name = sock.getsockname()
       
   453         # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
       
   454         # it reasonable to get the host's addr in addition to 0.0.0.0.
       
   455         # At least for eCos.  This is required for the S/390 to pass.
       
   456         my_ip_addr = socket.gethostbyname(socket.gethostname())
       
   457         self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
       
   458         self.assertEqual(name[1], PORT+1)
       
   459 
       
   460     def testGetSockOpt(self):
       
   461         # Testing getsockopt()
       
   462         # We know a socket should start without reuse==0
       
   463         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   464         reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
       
   465         self.failIf(reuse != 0, "initial mode is reuse")
       
   466 
       
   467     def testSetSockOpt(self):
       
   468         # Testing setsockopt()
       
   469         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   470         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
       
   471         reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
       
   472         self.failIf(reuse == 0, "failed to set reuse mode")
       
   473 
       
   474     def testSendAfterClose(self):
       
   475         # testing send() after close() with timeout
       
   476         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   477         sock.settimeout(1)
       
   478         sock.close()
       
   479         self.assertRaises(socket.error, sock.send, "spam")
       
   480 
       
   481     def testNewAttributes(self):
       
   482         # testing .family, .type and .protocol
       
   483         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
       
   484         self.assertEqual(sock.family, socket.AF_INET)
       
   485         self.assertEqual(sock.type, socket.SOCK_STREAM)
       
   486         self.assertEqual(sock.proto, 0)
       
   487         sock.close()
       
   488 
       
   489 class BasicTCPTest(SocketConnectedTest):
       
   490 
       
   491     def __init__(self, methodName='runTest'):
       
   492         SocketConnectedTest.__init__(self, methodName=methodName)
       
   493 
       
   494     def testRecv(self):
       
   495         # Testing large receive over TCP
       
   496         msg = self.cli_conn.recv(1024)
       
   497         self.assertEqual(msg, MSG)
       
   498 
       
   499     def _testRecv(self):
       
   500         self.serv_conn.send(MSG)
       
   501 
       
   502     def testOverFlowRecv(self):
       
   503         # Testing receive in chunks over TCP
       
   504         seg1 = self.cli_conn.recv(len(MSG) - 3)
       
   505         seg2 = self.cli_conn.recv(1024)
       
   506         msg = seg1 + seg2
       
   507         self.assertEqual(msg, MSG)
       
   508 
       
   509     def _testOverFlowRecv(self):
       
   510         self.serv_conn.send(MSG)
       
   511 
       
   512     def testRecvFrom(self):
       
   513         # Testing large recvfrom() over TCP
       
   514         msg, addr = self.cli_conn.recvfrom(1024)
       
   515         self.assertEqual(msg, MSG)
       
   516 
       
   517     def _testRecvFrom(self):
       
   518         self.serv_conn.send(MSG)
       
   519 
       
   520     def testOverFlowRecvFrom(self):
       
   521         # Testing recvfrom() in chunks over TCP
       
   522         seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
       
   523         seg2, addr = self.cli_conn.recvfrom(1024)
       
   524         msg = seg1 + seg2
       
   525         self.assertEqual(msg, MSG)
       
   526 
       
   527     def _testOverFlowRecvFrom(self):
       
   528         self.serv_conn.send(MSG)
       
   529 
       
   530     def testSendAll(self):
       
   531         # Testing sendall() with a 2048 byte string over TCP
       
   532         msg = ''
       
   533         while 1:
       
   534             read = self.cli_conn.recv(1024)
       
   535             if not read:
       
   536                 break
       
   537             msg += read
       
   538         self.assertEqual(msg, 'f' * 2048)
       
   539 
       
   540     def _testSendAll(self):
       
   541         big_chunk = 'f' * 2048
       
   542         self.serv_conn.sendall(big_chunk)
       
   543 
       
   544     def testFromFd(self):
       
   545         # Testing fromfd()
       
   546         if not hasattr(socket, "fromfd"):
       
   547             return # On Windows, this doesn't exist
       
   548         fd = self.cli_conn.fileno()
       
   549         sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
       
   550         msg = sock.recv(1024)
       
   551         self.assertEqual(msg, MSG)
       
   552 
       
   553     def _testFromFd(self):
       
   554         self.serv_conn.send(MSG)
       
   555 
       
   556     def testShutdown(self):
       
   557         # Testing shutdown()
       
   558         msg = self.cli_conn.recv(1024)
       
   559         self.assertEqual(msg, MSG)
       
   560 
       
   561     def _testShutdown(self):
       
   562         self.serv_conn.send(MSG)
       
   563         self.serv_conn.shutdown(2)
       
   564 
       
   565 class BasicUDPTest(ThreadedUDPSocketTest):
       
   566 
       
   567     def __init__(self, methodName='runTest'):
       
   568         ThreadedUDPSocketTest.__init__(self, methodName=methodName)
       
   569 
       
   570     def testSendtoAndRecv(self):
       
   571         # Testing sendto() and Recv() over UDP
       
   572         msg = self.serv.recv(len(MSG))
       
   573         self.assertEqual(msg, MSG)
       
   574 
       
   575     def _testSendtoAndRecv(self):
       
   576         self.cli.sendto(MSG, 0, (HOST, PORT))
       
   577 
       
   578     def testRecvFrom(self):
       
   579         # Testing recvfrom() over UDP
       
   580         msg, addr = self.serv.recvfrom(len(MSG))
       
   581         self.assertEqual(msg, MSG)
       
   582 
       
   583     def _testRecvFrom(self):
       
   584         self.cli.sendto(MSG, 0, (HOST, PORT))
       
   585 
       
   586     def testRecvFromNegative(self):
       
   587         # Negative lengths passed to recvfrom should give ValueError.
       
   588         self.assertRaises(ValueError, self.serv.recvfrom, -1)
       
   589 
       
   590     def _testRecvFromNegative(self):
       
   591         self.cli.sendto(MSG, 0, (HOST, PORT))
       
   592 
       
   593 class TCPCloserTest(ThreadedTCPSocketTest):
       
   594 
       
   595     def testClose(self):
       
   596         conn, addr = self.serv.accept()
       
   597         conn.close()
       
   598 
       
   599         sd = self.cli
       
   600         read, write, err = select.select([sd], [], [], 1.0)
       
   601         self.assertEqual(read, [sd])
       
   602         self.assertEqual(sd.recv(1), '')
       
   603 
       
   604     def _testClose(self):
       
   605         self.cli.connect((HOST, PORT))
       
   606         time.sleep(1.0)
       
   607 
       
   608 class BasicSocketPairTest(SocketPairTest):
       
   609 
       
   610     def __init__(self, methodName='runTest'):
       
   611         SocketPairTest.__init__(self, methodName=methodName)
       
   612 
       
   613     def testRecv(self):
       
   614         msg = self.serv.recv(1024)
       
   615         self.assertEqual(msg, MSG)
       
   616 
       
   617     def _testRecv(self):
       
   618         self.cli.send(MSG)
       
   619 
       
   620     def testSend(self):
       
   621         self.serv.send(MSG)
       
   622 
       
   623     def _testSend(self):
       
   624         msg = self.cli.recv(1024)
       
   625         self.assertEqual(msg, MSG)
       
   626 
       
   627 class NonBlockingTCPTests(ThreadedTCPSocketTest):
       
   628 
       
   629     def __init__(self, methodName='runTest'):
       
   630         ThreadedTCPSocketTest.__init__(self, methodName=methodName)
       
   631 
       
   632     def testSetBlocking(self):
       
   633         # Testing whether set blocking works
       
   634         self.serv.setblocking(0)
       
   635         start = time.time()
       
   636         try:
       
   637             self.serv.accept()
       
   638         except socket.error:
       
   639             pass
       
   640         end = time.time()
       
   641         self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")
       
   642 
       
   643     def _testSetBlocking(self):
       
   644         pass
       
   645 
       
   646     def testAccept(self):
       
   647         # Testing non-blocking accept
       
   648         self.serv.setblocking(0)
       
   649         try:
       
   650             conn, addr = self.serv.accept()
       
   651         except socket.error:
       
   652             pass
       
   653         else:
       
   654             self.fail("Error trying to do non-blocking accept.")
       
   655         read, write, err = select.select([self.serv], [], [])
       
   656         if self.serv in read:
       
   657             conn, addr = self.serv.accept()
       
   658         else:
       
   659             self.fail("Error trying to do accept after select.")
       
   660 
       
   661     def _testAccept(self):
       
   662         time.sleep(0.1)
       
   663         self.cli.connect((HOST, PORT))
       
   664 
       
   665     def testConnect(self):
       
   666         # Testing non-blocking connect
       
   667         conn, addr = self.serv.accept()
       
   668 
       
   669     def _testConnect(self):
       
   670         self.cli.settimeout(10)
       
   671         self.cli.connect((HOST, PORT))
       
   672 
       
   673     def testRecv(self):
       
   674         # Testing non-blocking recv
       
   675         conn, addr = self.serv.accept()
       
   676         conn.setblocking(0)
       
   677         try:
       
   678             msg = conn.recv(len(MSG))
       
   679         except socket.error:
       
   680             pass
       
   681         else:
       
   682             self.fail("Error trying to do non-blocking recv.")
       
   683         read, write, err = select.select([conn], [], [])
       
   684         if conn in read:
       
   685             msg = conn.recv(len(MSG))
       
   686             self.assertEqual(msg, MSG)
       
   687         else:
       
   688             self.fail("Error during select call to non-blocking socket.")
       
   689 
       
   690     def _testRecv(self):
       
   691         self.cli.connect((HOST, PORT))
       
   692         time.sleep(0.1)
       
   693         self.cli.send(MSG)
       
   694 
       
   695 class FileObjectClassTestCase(SocketConnectedTest):
       
   696 
       
   697     bufsize = -1 # Use default buffer size
       
   698 
       
   699     def __init__(self, methodName='runTest'):
       
   700         SocketConnectedTest.__init__(self, methodName=methodName)
       
   701 
       
   702     def setUp(self):
       
   703         SocketConnectedTest.setUp(self)
       
   704         self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
       
   705 
       
   706     def tearDown(self):
       
   707         self.serv_file.close()
       
   708         self.assert_(self.serv_file.closed)
       
   709         self.serv_file = None
       
   710         SocketConnectedTest.tearDown(self)
       
   711 
       
   712     def clientSetUp(self):
       
   713         SocketConnectedTest.clientSetUp(self)
       
   714         self.cli_file = self.serv_conn.makefile('wb')
       
   715 
       
   716     def clientTearDown(self):
       
   717         self.cli_file.close()
       
   718         self.assert_(self.cli_file.closed)
       
   719         self.cli_file = None
       
   720         SocketConnectedTest.clientTearDown(self)
       
   721 
       
   722     def testSmallRead(self):
       
   723         # Performing small file read test
       
   724         first_seg = self.serv_file.read(len(MSG)-3)
       
   725         second_seg = self.serv_file.read(3)
       
   726         msg = first_seg + second_seg
       
   727         self.assertEqual(msg, MSG)
       
   728 
       
   729     def _testSmallRead(self):
       
   730         self.cli_file.write(MSG)
       
   731         self.cli_file.flush()
       
   732 
       
   733     def testFullRead(self):
       
   734         # read until EOF
       
   735         msg = self.serv_file.read()
       
   736         self.assertEqual(msg, MSG)
       
   737 
       
   738     def _testFullRead(self):
       
   739         self.cli_file.write(MSG)
       
   740         self.cli_file.close()
       
   741 
       
   742     def testUnbufferedRead(self):
       
   743         # Performing unbuffered file read test
       
   744         buf = ''
       
   745         while 1:
       
   746             char = self.serv_file.read(1)
       
   747             if not char:
       
   748                 break
       
   749             buf += char
       
   750         self.assertEqual(buf, MSG)
       
   751 
       
   752     def _testUnbufferedRead(self):
       
   753         self.cli_file.write(MSG)
       
   754         self.cli_file.flush()
       
   755 
       
   756     def testReadline(self):
       
   757         # Performing file readline test
       
   758         line = self.serv_file.readline()
       
   759         self.assertEqual(line, MSG)
       
   760 
       
   761     def _testReadline(self):
       
   762         self.cli_file.write(MSG)
       
   763         self.cli_file.flush()
       
   764 
       
   765     def testClosedAttr(self):
       
   766         self.assert_(not self.serv_file.closed)
       
   767 
       
   768     def _testClosedAttr(self):
       
   769         self.assert_(not self.cli_file.closed)
       
   770 
       
   771 class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
       
   772 
       
   773     """Repeat the tests from FileObjectClassTestCase with bufsize==0.
       
   774 
       
   775     In this case (and in this case only), it should be possible to
       
   776     create a file object, read a line from it, create another file
       
   777     object, read another line from it, without loss of data in the
       
   778     first file object's buffer.  Note that httplib relies on this
       
   779     when reading multiple requests from the same socket."""
       
   780 
       
   781     bufsize = 0 # Use unbuffered mode
       
   782 
       
   783     def testUnbufferedReadline(self):
       
   784         # Read a line, create a new file object, read another line with it
       
   785         line = self.serv_file.readline() # first line
       
   786         self.assertEqual(line, "A. " + MSG) # first line
       
   787         self.serv_file = self.cli_conn.makefile('rb', 0)
       
   788         line = self.serv_file.readline() # second line
       
   789         self.assertEqual(line, "B. " + MSG) # second line
       
   790 
       
   791     def _testUnbufferedReadline(self):
       
   792         self.cli_file.write("A. " + MSG)
       
   793         self.cli_file.write("B. " + MSG)
       
   794         self.cli_file.flush()
       
   795 
       
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    # Inherits every test from FileObjectClassTestCase and only overrides
    # the buffer size handed to makefile() for the reading file object.

    bufsize = 1 # Default-buffered for reading; line-buffered for writing
       
   799 
       
   800 
       
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    # Inherits every test from FileObjectClassTestCase and only overrides
    # the buffer size handed to makefile() for the reading file object.

    bufsize = 2 # Exercise the buffering code
       
   804 
       
   805 
       
   806 class Urllib2FileobjectTest(unittest.TestCase):
       
   807 
       
   808     # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
       
   809     # it close the socket if the close c'tor argument is true
       
   810 
       
   811     def testClose(self):
       
   812         class MockSocket:
       
   813             closed = False
       
   814             def flush(self): pass
       
   815             def close(self): self.closed = True
       
   816 
       
   817         # must not close unless we request it: the original use of _fileobject
       
   818         # by module socket requires that the underlying socket not be closed until
       
   819         # the _socketobject that created the _fileobject is closed
       
   820         s = MockSocket()
       
   821         f = socket._fileobject(s)
       
   822         f.close()
       
   823         self.assert_(not s.closed)
       
   824 
       
   825         s = MockSocket()
       
   826         f = socket._fileobject(s, close=True)
       
   827         f.close()
       
   828         self.assert_(s.closed)
       
   829 
       
   830 class TCPTimeoutTest(SocketTCPTest):
       
   831 
       
   832     def testTCPTimeout(self):
       
   833         def raise_timeout(*args, **kwargs):
       
   834             self.serv.settimeout(1.0)
       
   835             self.serv.accept()
       
   836         self.failUnlessRaises(socket.timeout, raise_timeout,
       
   837                               "Error generating a timeout exception (TCP)")
       
   838 
       
   839     def testTimeoutZero(self):
       
   840         ok = False
       
   841         try:
       
   842             self.serv.settimeout(0.0)
       
   843             foo = self.serv.accept()
       
   844         except socket.timeout:
       
   845             self.fail("caught timeout instead of error (TCP)")
       
   846         except socket.error:
       
   847             ok = True
       
   848         except:
       
   849             self.fail("caught unexpected exception (TCP)")
       
   850         if not ok:
       
   851             self.fail("accept() returned success when we did not expect it")
       
   852 
       
   853     def testInterruptedTimeout(self):
       
   854         # XXX I don't know how to do this test on MSWindows or any other
       
   855         # plaform that doesn't support signal.alarm() or os.kill(), though
       
   856         # the bug should have existed on all platforms.
       
   857         if not hasattr(signal, "alarm"):
       
   858             return                  # can only test on *nix
       
   859         self.serv.settimeout(5.0)   # must be longer than alarm
       
   860         class Alarm(Exception):
       
   861             pass
       
   862         def alarm_handler(signal, frame):
       
   863             raise Alarm
       
   864         old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
       
   865         try:
       
   866             signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
       
   867             try:
       
   868                 foo = self.serv.accept()
       
   869             except socket.timeout:
       
   870                 self.fail("caught timeout instead of Alarm")
       
   871             except Alarm:
       
   872                 pass
       
   873             except:
       
   874                 self.fail("caught other exception instead of Alarm")
       
   875             else:
       
   876                 self.fail("nothing caught")
       
   877             signal.alarm(0)         # shut off alarm
       
   878         except Alarm:
       
   879             self.fail("got Alarm in wrong place")
       
   880         finally:
       
   881             # no alarm can be pending.  Safe to restore old handler.
       
   882             signal.signal(signal.SIGALRM, old_alarm)
       
   883 
       
   884 class UDPTimeoutTest(SocketTCPTest):
       
   885 
       
   886     def testUDPTimeout(self):
       
   887         def raise_timeout(*args, **kwargs):
       
   888             self.serv.settimeout(1.0)
       
   889             self.serv.recv(1024)
       
   890         self.failUnlessRaises(socket.timeout, raise_timeout,
       
   891                               "Error generating a timeout exception (UDP)")
       
   892 
       
   893     def testTimeoutZero(self):
       
   894         ok = False
       
   895         try:
       
   896             self.serv.settimeout(0.0)
       
   897             foo = self.serv.recv(1024)
       
   898         except socket.timeout:
       
   899             self.fail("caught timeout instead of error (UDP)")
       
   900         except socket.error:
       
   901             ok = True
       
   902         except:
       
   903             self.fail("caught unexpected exception (UDP)")
       
   904         if not ok:
       
   905             self.fail("recv() returned success when we did not expect it")
       
   906 
       
   907 class TestExceptions(unittest.TestCase):
       
   908 
       
   909     def testExceptionTree(self):
       
   910         self.assert_(issubclass(socket.error, Exception))
       
   911         self.assert_(issubclass(socket.herror, socket.error))
       
   912         self.assert_(issubclass(socket.gaierror, socket.error))
       
   913         self.assert_(issubclass(socket.timeout, socket.error))
       
   914 
       
   915 class TestLinuxAbstractNamespace(unittest.TestCase):
       
   916 
       
   917     UNIX_PATH_MAX = 108
       
   918 
       
   919     def testLinuxAbstractNamespace(self):
       
   920         address = "\x00python-test-hello\x00\xff"
       
   921         s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
       
   922         s1.bind(address)
       
   923         s1.listen(1)
       
   924         s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
       
   925         s2.connect(s1.getsockname())
       
   926         s1.accept()
       
   927         self.assertEqual(s1.getsockname(), address)
       
   928         self.assertEqual(s2.getpeername(), address)
       
   929 
       
   930     def testMaxName(self):
       
   931         address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
       
   932         s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
       
   933         s.bind(address)
       
   934         self.assertEqual(s.getsockname(), address)
       
   935 
       
   936     def testNameOverflow(self):
       
   937         address = "\x00" + "h" * self.UNIX_PATH_MAX
       
   938         s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
       
   939         self.assertRaises(socket.error, s.bind, address)
       
   940 
       
   941 
       
   942 class BufferIOTest(SocketConnectedTest):
       
   943     """
       
   944     Test the buffer versions of socket.recv() and socket.send().
       
   945     """
       
   946     def __init__(self, methodName='runTest'):
       
   947         SocketConnectedTest.__init__(self, methodName=methodName)
       
   948 
       
   949     def testRecvInto(self):
       
   950         buf = array.array('c', ' '*1024)
       
   951         nbytes = self.cli_conn.recv_into(buf)
       
   952         self.assertEqual(nbytes, len(MSG))
       
   953         msg = buf.tostring()[:len(MSG)]
       
   954         self.assertEqual(msg, MSG)
       
   955 
       
   956     def _testRecvInto(self):
       
   957         buf = buffer(MSG)
       
   958         self.serv_conn.send(buf)
       
   959 
       
   960     def testRecvFromInto(self):
       
   961         buf = array.array('c', ' '*1024)
       
   962         nbytes, addr = self.cli_conn.recvfrom_into(buf)
       
   963         self.assertEqual(nbytes, len(MSG))
       
   964         msg = buf.tostring()[:len(MSG)]
       
   965         self.assertEqual(msg, MSG)
       
   966 
       
   967     def _testRecvFromInto(self):
       
   968         buf = buffer(MSG)
       
   969         self.serv_conn.send(buf)
       
   970 
       
   971 def test_main():
       
   972     tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
       
   973              TestExceptions, BufferIOTest]
       
   974     if sys.platform != 'mac':
       
   975         tests.extend([ BasicUDPTest, UDPTimeoutTest ])
       
   976 
       
   977     tests.extend([
       
   978         NonBlockingTCPTests,
       
   979         FileObjectClassTestCase,
       
   980         UnbufferedFileObjectClassTestCase,
       
   981         LineBufferedFileObjectClassTestCase,
       
   982         SmallBufferedFileObjectClassTestCase,
       
   983         Urllib2FileobjectTest,
       
   984     ])
       
   985     if hasattr(socket, "socketpair"):
       
   986         tests.append(BasicSocketPairTest)
       
   987     if sys.platform == 'linux2':
       
   988         tests.append(TestLinuxAbstractNamespace)
       
   989 
       
   990     thread_info = test_support.threading_setup()
       
   991     test_support.run_unittest(*tests)
       
   992     test_support.threading_cleanup(*thread_info)
       
   993 
       
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()