|
1 #!/usr/bin/env python2.5 |
|
2 # Copyright 2006 Google, Inc. All Rights Reserved. |
|
3 # Licensed to PSF under a Contributor Agreement. |
|
4 |
|
5 """Refactoring framework. |
|
6 |
|
7 Used as a main program, this can refactor any number of files and/or |
|
8 recursively descend down directories. Imported as a module, this |
|
9 provides infrastructure to write your own refactoring tool. |
|
10 """ |
|
11 |
|
12 __author__ = "Guido van Rossum <guido@python.org>" |
|
13 |
|
14 |
|
15 # Python imports |
|
16 import os |
|
17 import sys |
|
18 import difflib |
|
19 import logging |
|
20 import operator |
|
21 from collections import defaultdict |
|
22 from itertools import chain |
|
23 |
|
24 # Local imports |
|
25 from .pgen2 import driver |
|
26 from .pgen2 import tokenize |
|
27 |
|
28 from . import pytree |
|
29 from . import patcomp |
|
30 from . import fixes |
|
31 from . import pygram |
|
32 |
|
33 |
|
def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fixer_dir = os.path.dirname(pkg.__file__)
    names = []
    for entry in sorted(os.listdir(fixer_dir)):
        # Fixer modules are the files named fix_<something>.py.
        if not (entry.startswith("fix_") and entry.endswith(".py")):
            continue
        stem = entry[:-3]  # drop the ".py" extension
        if remove_prefix:
            stem = stem[4:]  # drop the "fix_" prefix
        names.append(stem)
    return names
|
45 |
|
def get_head_types(pat):
    """Return the set of pattern types which will match first.

    Accepts a pytree Pattern node; the result is used to index fixers
    by the head node types their patterns can fire on.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        # Always return leafs.
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        # Negated patterns carry no type of their own; delegate to the
        # wrapped content when there is one.
        if pat.content:
            return get_head_types(pat.content)
        return set([None])

    if isinstance(pat, pytree.WildcardPattern):
        # Union of the head types of every subpattern in every alternative.
        types = set()
        for alternative in pat.content:
            for subpattern in alternative:
                types |= get_head_types(subpattern)
        return types

    raise Exception("Oh no! I don't understand pattern %s" % (pat))
|
70 |
|
def get_headnode_dict(fixer_list):
    """Map head node type --> list of fixers whose patterns start there.

    Fixers without a pattern are filed under the key None so callers can
    always consult them.
    """
    mapping = defaultdict(list)
    for fixer in fixer_list:
        if fixer.pattern:
            for head_type in get_head_types(fixer.pattern):
                mapping[head_type].append(fixer)
        else:
            mapping[None].append(fixer)
    return mapping
|
82 |
|
def get_fixers_from_package(pkg_name):
    """Return the fully qualified names for fixers in the package pkg_name."""
    return ["%s.%s" % (pkg_name, fix_name)
            for fix_name in get_all_fix_names(pkg_name, False)]
|
89 |
|
90 |
|
class FixerError(Exception):
    """Raised when a fixer module or class could not be loaded."""
|
93 |
|
94 |
|
95 class RefactoringTool(object): |
|
96 |
|
97 _default_options = {"print_function": False} |
|
98 |
|
99 CLASS_PREFIX = "Fix" # The prefix for fixer classes |
|
100 FILE_PREFIX = "fix_" # The prefix for modules with a fixer within |
|
101 |
|
102 def __init__(self, fixer_names, options=None, explicit=None): |
|
103 """Initializer. |
|
104 |
|
105 Args: |
|
106 fixer_names: a list of fixers to import |
|
107 options: an dict with configuration. |
|
108 explicit: a list of fixers to run even if they are explicit. |
|
109 """ |
|
110 self.fixers = fixer_names |
|
111 self.explicit = explicit or [] |
|
112 self.options = self._default_options.copy() |
|
113 if options is not None: |
|
114 self.options.update(options) |
|
115 self.errors = [] |
|
116 self.logger = logging.getLogger("RefactoringTool") |
|
117 self.fixer_log = [] |
|
118 self.wrote = False |
|
119 if self.options["print_function"]: |
|
120 del pygram.python_grammar.keywords["print"] |
|
121 self.driver = driver.Driver(pygram.python_grammar, |
|
122 convert=pytree.convert, |
|
123 logger=self.logger) |
|
124 self.pre_order, self.post_order = self.get_fixers() |
|
125 |
|
126 self.pre_order_mapping = get_headnode_dict(self.pre_order) |
|
127 self.post_order_mapping = get_headnode_dict(self.post_order) |
|
128 |
|
129 self.files = [] # List of files that were or should be modified |
|
130 |
|
131 def get_fixers(self): |
|
132 """Inspects the options to load the requested patterns and handlers. |
|
133 |
|
134 Returns: |
|
135 (pre_order, post_order), where pre_order is the list of fixers that |
|
136 want a pre-order AST traversal, and post_order is the list that want |
|
137 post-order traversal. |
|
138 """ |
|
139 pre_order_fixers = [] |
|
140 post_order_fixers = [] |
|
141 for fix_mod_path in self.fixers: |
|
142 mod = __import__(fix_mod_path, {}, {}, ["*"]) |
|
143 fix_name = fix_mod_path.rsplit(".", 1)[-1] |
|
144 if fix_name.startswith(self.FILE_PREFIX): |
|
145 fix_name = fix_name[len(self.FILE_PREFIX):] |
|
146 parts = fix_name.split("_") |
|
147 class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts]) |
|
148 try: |
|
149 fix_class = getattr(mod, class_name) |
|
150 except AttributeError: |
|
151 raise FixerError("Can't find %s.%s" % (fix_name, class_name)) |
|
152 fixer = fix_class(self.options, self.fixer_log) |
|
153 if fixer.explicit and self.explicit is not True and \ |
|
154 fix_mod_path not in self.explicit: |
|
155 self.log_message("Skipping implicit fixer: %s", fix_name) |
|
156 continue |
|
157 |
|
158 self.log_debug("Adding transformation: %s", fix_name) |
|
159 if fixer.order == "pre": |
|
160 pre_order_fixers.append(fixer) |
|
161 elif fixer.order == "post": |
|
162 post_order_fixers.append(fixer) |
|
163 else: |
|
164 raise FixerError("Illegal fixer order: %r" % fixer.order) |
|
165 |
|
166 key_func = operator.attrgetter("run_order") |
|
167 pre_order_fixers.sort(key=key_func) |
|
168 post_order_fixers.sort(key=key_func) |
|
169 return (pre_order_fixers, post_order_fixers) |
|
170 |
|
171 def log_error(self, msg, *args, **kwds): |
|
172 """Called when an error occurs.""" |
|
173 raise |
|
174 |
|
175 def log_message(self, msg, *args): |
|
176 """Hook to log a message.""" |
|
177 if args: |
|
178 msg = msg % args |
|
179 self.logger.info(msg) |
|
180 |
|
181 def log_debug(self, msg, *args): |
|
182 if args: |
|
183 msg = msg % args |
|
184 self.logger.debug(msg) |
|
185 |
|
186 def print_output(self, lines): |
|
187 """Called with lines of output to give to the user.""" |
|
188 pass |
|
189 |
|
190 def refactor(self, items, write=False, doctests_only=False): |
|
191 """Refactor a list of files and directories.""" |
|
192 for dir_or_file in items: |
|
193 if os.path.isdir(dir_or_file): |
|
194 self.refactor_dir(dir_or_file, write, doctests_only) |
|
195 else: |
|
196 self.refactor_file(dir_or_file, write, doctests_only) |
|
197 |
|
198 def refactor_dir(self, dir_name, write=False, doctests_only=False): |
|
199 """Descends down a directory and refactor every Python file found. |
|
200 |
|
201 Python files are assumed to have a .py extension. |
|
202 |
|
203 Files and subdirectories starting with '.' are skipped. |
|
204 """ |
|
205 for dirpath, dirnames, filenames in os.walk(dir_name): |
|
206 self.log_debug("Descending into %s", dirpath) |
|
207 dirnames.sort() |
|
208 filenames.sort() |
|
209 for name in filenames: |
|
210 if not name.startswith(".") and name.endswith("py"): |
|
211 fullname = os.path.join(dirpath, name) |
|
212 self.refactor_file(fullname, write, doctests_only) |
|
213 # Modify dirnames in-place to remove subdirs with leading dots |
|
214 dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] |
|
215 |
|
216 def refactor_file(self, filename, write=False, doctests_only=False): |
|
217 """Refactors a file.""" |
|
218 try: |
|
219 f = open(filename) |
|
220 except IOError, err: |
|
221 self.log_error("Can't open %s: %s", filename, err) |
|
222 return |
|
223 try: |
|
224 input = f.read() + "\n" # Silence certain parse errors |
|
225 finally: |
|
226 f.close() |
|
227 if doctests_only: |
|
228 self.log_debug("Refactoring doctests in %s", filename) |
|
229 output = self.refactor_docstring(input, filename) |
|
230 if output != input: |
|
231 self.processed_file(output, filename, input, write=write) |
|
232 else: |
|
233 self.log_debug("No doctest changes in %s", filename) |
|
234 else: |
|
235 tree = self.refactor_string(input, filename) |
|
236 if tree and tree.was_changed: |
|
237 # The [:-1] is to take off the \n we added earlier |
|
238 self.processed_file(str(tree)[:-1], filename, write=write) |
|
239 else: |
|
240 self.log_debug("No changes in %s", filename) |
|
241 |
|
242 def refactor_string(self, data, name): |
|
243 """Refactor a given input string. |
|
244 |
|
245 Args: |
|
246 data: a string holding the code to be refactored. |
|
247 name: a human-readable name for use in error/log messages. |
|
248 |
|
249 Returns: |
|
250 An AST corresponding to the refactored input stream; None if |
|
251 there were errors during the parse. |
|
252 """ |
|
253 try: |
|
254 tree = self.driver.parse_string(data) |
|
255 except Exception, err: |
|
256 self.log_error("Can't parse %s: %s: %s", |
|
257 name, err.__class__.__name__, err) |
|
258 return |
|
259 self.log_debug("Refactoring %s", name) |
|
260 self.refactor_tree(tree, name) |
|
261 return tree |
|
262 |
|
263 def refactor_stdin(self, doctests_only=False): |
|
264 input = sys.stdin.read() |
|
265 if doctests_only: |
|
266 self.log_debug("Refactoring doctests in stdin") |
|
267 output = self.refactor_docstring(input, "<stdin>") |
|
268 if output != input: |
|
269 self.processed_file(output, "<stdin>", input) |
|
270 else: |
|
271 self.log_debug("No doctest changes in stdin") |
|
272 else: |
|
273 tree = self.refactor_string(input, "<stdin>") |
|
274 if tree and tree.was_changed: |
|
275 self.processed_file(str(tree), "<stdin>", input) |
|
276 else: |
|
277 self.log_debug("No changes in stdin") |
|
278 |
|
279 def refactor_tree(self, tree, name): |
|
280 """Refactors a parse tree (modifying the tree in place). |
|
281 |
|
282 Args: |
|
283 tree: a pytree.Node instance representing the root of the tree |
|
284 to be refactored. |
|
285 name: a human-readable name for this tree. |
|
286 |
|
287 Returns: |
|
288 True if the tree was modified, False otherwise. |
|
289 """ |
|
290 # Two calls to chain are required because pre_order.values() |
|
291 # will be a list of lists of fixers: |
|
292 # [[<fixer ...>, <fixer ...>], [<fixer ...>]] |
|
293 all_fixers = chain(self.pre_order, self.post_order) |
|
294 for fixer in all_fixers: |
|
295 fixer.start_tree(tree, name) |
|
296 |
|
297 self.traverse_by(self.pre_order_mapping, tree.pre_order()) |
|
298 self.traverse_by(self.post_order_mapping, tree.post_order()) |
|
299 |
|
300 for fixer in all_fixers: |
|
301 fixer.finish_tree(tree, name) |
|
302 return tree.was_changed |
|
303 |
|
304 def traverse_by(self, fixers, traversal): |
|
305 """Traverse an AST, applying a set of fixers to each node. |
|
306 |
|
307 This is a helper method for refactor_tree(). |
|
308 |
|
309 Args: |
|
310 fixers: a list of fixer instances. |
|
311 traversal: a generator that yields AST nodes. |
|
312 |
|
313 Returns: |
|
314 None |
|
315 """ |
|
316 if not fixers: |
|
317 return |
|
318 for node in traversal: |
|
319 for fixer in fixers[node.type] + fixers[None]: |
|
320 results = fixer.match(node) |
|
321 if results: |
|
322 new = fixer.transform(node, results) |
|
323 if new is not None and (new != node or |
|
324 str(new) != str(node)): |
|
325 node.replace(new) |
|
326 node = new |
|
327 |
|
328 def processed_file(self, new_text, filename, old_text=None, write=False): |
|
329 """ |
|
330 Called when a file has been refactored, and there are changes. |
|
331 """ |
|
332 self.files.append(filename) |
|
333 if old_text is None: |
|
334 try: |
|
335 f = open(filename, "r") |
|
336 except IOError, err: |
|
337 self.log_error("Can't read %s: %s", filename, err) |
|
338 return |
|
339 try: |
|
340 old_text = f.read() |
|
341 finally: |
|
342 f.close() |
|
343 if old_text == new_text: |
|
344 self.log_debug("No changes to %s", filename) |
|
345 return |
|
346 self.print_output(diff_texts(old_text, new_text, filename)) |
|
347 if write: |
|
348 self.write_file(new_text, filename, old_text) |
|
349 else: |
|
350 self.log_debug("Not writing changes to %s", filename) |
|
351 |
|
352 def write_file(self, new_text, filename, old_text): |
|
353 """Writes a string to a file. |
|
354 |
|
355 It first shows a unified diff between the old text and the new text, and |
|
356 then rewrites the file; the latter is only done if the write option is |
|
357 set. |
|
358 """ |
|
359 try: |
|
360 f = open(filename, "w") |
|
361 except os.error, err: |
|
362 self.log_error("Can't create %s: %s", filename, err) |
|
363 return |
|
364 try: |
|
365 f.write(new_text) |
|
366 except os.error, err: |
|
367 self.log_error("Can't write %s: %s", filename, err) |
|
368 finally: |
|
369 f.close() |
|
370 self.log_debug("Wrote changes to %s", filename) |
|
371 self.wrote = True |
|
372 |
|
373 PS1 = ">>> " |
|
374 PS2 = "... " |
|
375 |
|
376 def refactor_docstring(self, input, filename): |
|
377 """Refactors a docstring, looking for doctests. |
|
378 |
|
379 This returns a modified version of the input string. It looks |
|
380 for doctests, which start with a ">>>" prompt, and may be |
|
381 continued with "..." prompts, as long as the "..." is indented |
|
382 the same as the ">>>". |
|
383 |
|
384 (Unfortunately we can't use the doctest module's parser, |
|
385 since, like most parsers, it is not geared towards preserving |
|
386 the original source.) |
|
387 """ |
|
388 result = [] |
|
389 block = None |
|
390 block_lineno = None |
|
391 indent = None |
|
392 lineno = 0 |
|
393 for line in input.splitlines(True): |
|
394 lineno += 1 |
|
395 if line.lstrip().startswith(self.PS1): |
|
396 if block is not None: |
|
397 result.extend(self.refactor_doctest(block, block_lineno, |
|
398 indent, filename)) |
|
399 block_lineno = lineno |
|
400 block = [line] |
|
401 i = line.find(self.PS1) |
|
402 indent = line[:i] |
|
403 elif (indent is not None and |
|
404 (line.startswith(indent + self.PS2) or |
|
405 line == indent + self.PS2.rstrip() + "\n")): |
|
406 block.append(line) |
|
407 else: |
|
408 if block is not None: |
|
409 result.extend(self.refactor_doctest(block, block_lineno, |
|
410 indent, filename)) |
|
411 block = None |
|
412 indent = None |
|
413 result.append(line) |
|
414 if block is not None: |
|
415 result.extend(self.refactor_doctest(block, block_lineno, |
|
416 indent, filename)) |
|
417 return "".join(result) |
|
418 |
|
419 def refactor_doctest(self, block, lineno, indent, filename): |
|
420 """Refactors one doctest. |
|
421 |
|
422 A doctest is given as a block of lines, the first of which starts |
|
423 with ">>>" (possibly indented), while the remaining lines start |
|
424 with "..." (identically indented). |
|
425 |
|
426 """ |
|
427 try: |
|
428 tree = self.parse_block(block, lineno, indent) |
|
429 except Exception, err: |
|
430 if self.log.isEnabledFor(logging.DEBUG): |
|
431 for line in block: |
|
432 self.log_debug("Source: %s", line.rstrip("\n")) |
|
433 self.log_error("Can't parse docstring in %s line %s: %s: %s", |
|
434 filename, lineno, err.__class__.__name__, err) |
|
435 return block |
|
436 if self.refactor_tree(tree, filename): |
|
437 new = str(tree).splitlines(True) |
|
438 # Undo the adjustment of the line numbers in wrap_toks() below. |
|
439 clipped, new = new[:lineno-1], new[lineno-1:] |
|
440 assert clipped == ["\n"] * (lineno-1), clipped |
|
441 if not new[-1].endswith("\n"): |
|
442 new[-1] += "\n" |
|
443 block = [indent + self.PS1 + new.pop(0)] |
|
444 if new: |
|
445 block += [indent + self.PS2 + line for line in new] |
|
446 return block |
|
447 |
|
448 def summarize(self): |
|
449 if self.wrote: |
|
450 were = "were" |
|
451 else: |
|
452 were = "need to be" |
|
453 if not self.files: |
|
454 self.log_message("No files %s modified.", were) |
|
455 else: |
|
456 self.log_message("Files that %s modified:", were) |
|
457 for file in self.files: |
|
458 self.log_message(file) |
|
459 if self.fixer_log: |
|
460 self.log_message("Warnings/messages while refactoring:") |
|
461 for message in self.fixer_log: |
|
462 self.log_message(message) |
|
463 if self.errors: |
|
464 if len(self.errors) == 1: |
|
465 self.log_message("There was 1 error:") |
|
466 else: |
|
467 self.log_message("There were %d errors:", len(self.errors)) |
|
468 for msg, args, kwds in self.errors: |
|
469 self.log_message(msg, *args, **kwds) |
|
470 |
|
471 def parse_block(self, block, lineno, indent): |
|
472 """Parses a block into a tree. |
|
473 |
|
474 This is necessary to get correct line number / offset information |
|
475 in the parser diagnostics and embedded into the parse tree. |
|
476 """ |
|
477 return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent)) |
|
478 |
|
479 def wrap_toks(self, block, lineno, indent): |
|
480 """Wraps a tokenize stream to systematically modify start/end.""" |
|
481 tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next) |
|
482 for type, value, (line0, col0), (line1, col1), line_text in tokens: |
|
483 line0 += lineno - 1 |
|
484 line1 += lineno - 1 |
|
485 # Don't bother updating the columns; this is too complicated |
|
486 # since line_text would also have to be updated and it would |
|
487 # still break for tokens spanning lines. Let the user guess |
|
488 # that the column numbers for doctests are relative to the |
|
489 # end of the prompt string (PS1 or PS2). |
|
490 yield type, value, (line0, col0), (line1, col1), line_text |
|
491 |
|
492 |
|
493 def gen_lines(self, block, indent): |
|
494 """Generates lines as expected by tokenize from a list of lines. |
|
495 |
|
496 This strips the first len(indent + self.PS1) characters off each line. |
|
497 """ |
|
498 prefix1 = indent + self.PS1 |
|
499 prefix2 = indent + self.PS2 |
|
500 prefix = prefix1 |
|
501 for line in block: |
|
502 if line.startswith(prefix): |
|
503 yield line[len(prefix):] |
|
504 elif line == prefix.rstrip() + "\n": |
|
505 yield "\n" |
|
506 else: |
|
507 raise AssertionError("line=%r, prefix=%r" % (line, prefix)) |
|
508 prefix = prefix2 |
|
509 while True: |
|
510 yield "" |
|
511 |
|
512 |
|
def diff_texts(a, b, filename):
    """Return a unified diff of two strings.

    The diff is a line generator; *filename* is used for both the
    "from" and "to" headers, tagged "(original)" and "(refactored)".
    """
    old_lines = a.splitlines()
    new_lines = b.splitlines()
    return difflib.unified_diff(old_lines, new_lines,
                                filename, filename,
                                "(original)", "(refactored)",
                                lineterm="")