|
1 #-*- coding: ISO-8859-1 -*- |
|
2 # pysqlite2/test/factory.py: tests for the various factories in pysqlite |
|
3 # |
|
4 # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> |
|
5 # |
|
6 # This file is part of pysqlite. |
|
7 # |
|
8 # This software is provided 'as-is', without any express or implied |
|
9 # warranty. In no event will the authors be held liable for any damages |
|
10 # arising from the use of this software. |
|
11 # |
|
12 # Permission is granted to anyone to use this software for any purpose, |
|
13 # including commercial applications, and to alter it and redistribute it |
|
14 # freely, subject to the following restrictions: |
|
15 # |
|
16 # 1. The origin of this software must not be misrepresented; you must not |
|
17 # claim that you wrote the original software. If you use this software |
|
18 # in a product, an acknowledgment in the product documentation would be |
|
19 # appreciated but is not required. |
|
20 # 2. Altered source versions must be plainly marked as such, and must not be |
|
21 # misrepresented as being the original software. |
|
22 # 3. This notice may not be removed or altered from any source distribution. |
|
23 |
|
24 import unittest |
|
25 import sqlite3 as sqlite |
|
26 |
|
class MyConnection(sqlite.Connection):
    """Minimal sqlite.Connection subclass used to exercise connect(factory=...).

    It adds no behavior of its own; every constructor argument is passed
    straight through to the base connection class.
    """

    def __init__(self, *args, **kwargs):
        # Forward all positional and keyword arguments untouched.
        super(MyConnection, self).__init__(*args, **kwargs)
|
30 |
|
def dict_factory(cursor, row):
    """Row factory that returns a dict mapping column names to values.

    The name of each column is the first item of the matching
    ``cursor.description`` entry, and its value is taken from ``row`` at the
    same position. If two columns share a name, the later column wins.
    """
    return dict((column[0], row[position])
                for position, column in enumerate(cursor.description))
|
36 |
|
class MyCursor(sqlite.Cursor):
    """sqlite.Cursor subclass whose fetched rows are built by dict_factory."""

    def __init__(self, *args, **kwargs):
        super(MyCursor, self).__init__(*args, **kwargs)
        # Install the module-level dict_factory so rows come back as dicts
        # keyed by column name instead of tuples.
        self.row_factory = dict_factory
|
41 |
|
class ConnectionFactoryTests(unittest.TestCase):
    """Checks that sqlite.connect() builds connections from a custom factory."""

    def setUp(self):
        self.con = sqlite.connect(":memory:", factory=MyConnection)

    def tearDown(self):
        self.con.close()

    def CheckIsInstance(self):
        # assertTrue replaces failUnless, a deprecated alias that was removed
        # from unittest in Python 3.12.
        self.assertTrue(isinstance(self.con, MyConnection),
                        "connection is not instance of MyConnection")
|
53 |
|
class CursorFactoryTests(unittest.TestCase):
    """Checks that Connection.cursor() builds cursors from a custom factory."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckIsInstance(self):
        cur = self.con.cursor(factory=MyCursor)
        # assertTrue replaces failUnless, a deprecated alias that was removed
        # from unittest in Python 3.12.
        self.assertTrue(isinstance(cur, MyCursor),
                        "cursor is not instance of MyCursor")
|
66 |
|
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
    """Checks that a row_factory installed on a cursor subclass is honoured."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckIsProducedByFactory(self):
        cur = self.con.cursor(factory=MyCursor)
        try:
            cur.execute("select 4+5 as foo")
            row = cur.fetchone()
            self.assertTrue(isinstance(row, dict),
                            "row is not instance of dict")
            # MyCursor installs dict_factory, which keys values by column name.
            self.assertEqual(row, {"foo": 9})
        finally:
            # Release the cursor even when one of the assertions above fails.
            cur.close()
|
82 |
|
class RowFactoryTests(unittest.TestCase):
    """Checks connection-level row factories, including sqlite.Row."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckCustomFactory(self):
        self.con.row_factory = lambda cur, row: list(row)
        row = self.con.execute("select 1, 2").fetchone()
        self.assertTrue(isinstance(row, list),
                        "row is not instance of list")

    def CheckSqliteRowIndex(self):
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        self.assertTrue(isinstance(row, sqlite.Row),
                        "row is not instance of sqlite.Row")

        col1, col2 = row["a"], row["b"]
        self.assertEqual(col1, 1, "by name: wrong result for column 'a'")
        self.assertEqual(col2, 2, "by name: wrong result for column 'b'")

        # Lookup by column name is expected to be case-insensitive.
        col1, col2 = row["A"], row["B"]
        self.assertEqual(col1, 1, "by name: wrong result for column 'A'")
        self.assertEqual(col2, 2, "by name: wrong result for column 'B'")

        col1, col2 = row[0], row[1]
        self.assertEqual(col1, 1, "by index: wrong result for column 0")
        self.assertEqual(col2, 2, "by index: wrong result for column 1")

    def CheckSqliteRowIter(self):
        """Checks if the row object is iterable"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        # Iteration must yield the column values in select-list order.
        values = [col for col in row]
        self.assertEqual(values, [1, 2])

    def CheckSqliteRowAsTuple(self):
        """Checks if the row object can be converted to a tuple"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        self.assertEqual(tuple(row), (1, 2))

    def CheckSqliteRowAsDict(self):
        """Checks if the row object can be correctly converted to a dictionary"""
        self.con.row_factory = sqlite.Row
        row = self.con.execute("select 1 as a, 2 as b").fetchone()
        d = dict(row)
        self.assertEqual(d["a"], row["a"])
        self.assertEqual(d["b"], row["b"])

    def CheckSqliteRowHashCmp(self):
        """Checks if the row object compares and hashes correctly"""
        self.con.row_factory = sqlite.Row
        row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
        row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()

        # Exercise == and != explicitly in both directions, because each
        # operator is implemented separately.
        self.assertTrue(row_1 == row_1)
        self.assertTrue(row_1 == row_2)
        self.assertTrue(row_2 != row_3)

        self.assertFalse(row_1 != row_1)
        self.assertFalse(row_1 != row_2)
        self.assertFalse(row_2 == row_3)

        self.assertEqual(row_1, row_2)
        self.assertEqual(hash(row_1), hash(row_2))
        self.assertNotEqual(row_1, row_3)
        self.assertNotEqual(hash(row_1), hash(row_3))
|
156 |
|
class TextFactoryTests(unittest.TestCase):
    """Checks how Connection.text_factory controls the type of text results.

    NOTE: these checks rely on the Python 2 ``unicode`` builtin, on ``str``
    being a byte string, and on ``sqlite.OptimizedUnicode``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

    def tearDown(self):
        self.con.close()

    def CheckUnicode(self):
        austria = unicode("Österreich", "latin1")
        row = self.con.execute("select ?", (austria,)).fetchone()
        self.assertTrue(type(row[0]) is unicode, "type of row[0] must be unicode")

    def CheckString(self):
        self.con.text_factory = str
        austria = unicode("Österreich", "latin1")
        row = self.con.execute("select ?", (austria,)).fetchone()
        self.assertTrue(type(row[0]) is str, "type of row[0] must be str")
        self.assertEqual(row[0], austria.encode("utf-8"),
                         "column must equal original data in UTF-8")

    def CheckCustom(self):
        self.con.text_factory = lambda x: unicode(x, "utf-8", "ignore")
        austria = unicode("Österreich", "latin1")
        # Latin-1 bytes are stored, and the factory decodes them as UTF-8 with
        # errors ignored, so the non-ASCII prefix may be dropped. Only the
        # ASCII suffix is compared.
        row = self.con.execute("select ?", (austria.encode("latin1"),)).fetchone()
        self.assertTrue(type(row[0]) is unicode, "type of row[0] must be unicode")
        self.assertTrue(row[0].endswith(u"reich"), "column must contain original data")

    def CheckOptimizedUnicode(self):
        self.con.text_factory = sqlite.OptimizedUnicode
        austria = unicode("Österreich", "latin1")
        germany = unicode("Deutchland")
        a_row = self.con.execute("select ?", (austria,)).fetchone()
        d_row = self.con.execute("select ?", (germany,)).fetchone()
        # Exact type checks are intentional: the factory must pick unicode
        # versus str depending on whether the text is pure ASCII.
        self.assertTrue(type(a_row[0]) is unicode, "type of non-ASCII row must be unicode")
        self.assertTrue(type(d_row[0]) is str, "type of ASCII-only row must be str")
|
191 |
|
def suite():
    """Return a TestSuite holding every ``Check*`` test in this module.

    The test methods use the legacy ``Check`` prefix rather than unittest's
    default ``test``, so the loader is configured to look for that prefix.
    ``TestLoader`` is used instead of ``unittest.makeSuite``, which was
    deprecated in Python 3.11 and removed in 3.13.
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = "Check"
    test_cases = (ConnectionFactoryTests, CursorFactoryTests,
                  RowFactoryTestsBackwardsCompat, RowFactoryTests,
                  TextFactoryTests)
    return unittest.TestSuite([loader.loadTestsFromTestCase(case)
                               for case in test_cases])
|
199 |
|
def test():
    """Run the whole factory test suite with a plain-text runner."""
    unittest.TextTestRunner().run(suite())


if __name__ == "__main__":
    test()