symbian-qemu-0.9.1-12/python-win32-2.6.1/lib/lib2to3/refactor.py
changeset 1 2fb8b9db1c86
equal deleted inserted replaced
0:ffa851df0825 1:2fb8b9db1c86
       
     1 #!/usr/bin/env python2.5
       
     2 # Copyright 2006 Google, Inc. All Rights Reserved.
       
     3 # Licensed to PSF under a Contributor Agreement.
       
     4 
       
"""Refactoring framework.

This module provides infrastructure for writing refactoring tools:
helpers that discover and load fixer modules from a package, and the
RefactoringTool class, which applies the loaded fixers to files, whole
directory trees, standard input, strings, or the doctests embedded in
them, reports the resulting changes as unified diffs and optionally
writes them back.
"""

__author__ = "Guido van Rossum <guido@python.org>"
       
    13 
       
    14 
       
    15 # Python imports
       
    16 import os
       
    17 import sys
       
    18 import difflib
       
    19 import logging
       
    20 import operator
       
    21 from collections import defaultdict
       
    22 from itertools import chain
       
    23 
       
    24 # Local imports
       
    25 from .pgen2 import driver
       
    26 from .pgen2 import tokenize
       
    27 
       
    28 from . import pytree
       
    29 from . import patcomp
       
    30 from . import fixes
       
    31 from . import pygram
       
    32 
       
    33 
       
    34 def get_all_fix_names(fixer_pkg, remove_prefix=True):
       
    35     """Return a sorted list of all available fix names in the given package."""
       
    36     pkg = __import__(fixer_pkg, [], [], ["*"])
       
    37     fixer_dir = os.path.dirname(pkg.__file__)
       
    38     fix_names = []
       
    39     for name in sorted(os.listdir(fixer_dir)):
       
    40         if name.startswith("fix_") and name.endswith(".py"):
       
    41             if remove_prefix:
       
    42                 name = name[4:]
       
    43             fix_names.append(name[:-3])
       
    44     return fix_names
       
    45 
       
    46 def get_head_types(pat):
       
    47     """ Accepts a pytree Pattern Node and returns a set
       
    48         of the pattern types which will match first. """
       
    49 
       
    50     if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
       
    51         # NodePatters must either have no type and no content
       
    52         #   or a type and content -- so they don't get any farther
       
    53         # Always return leafs
       
    54         return set([pat.type])
       
    55 
       
    56     if isinstance(pat, pytree.NegatedPattern):
       
    57         if pat.content:
       
    58             return get_head_types(pat.content)
       
    59         return set([None]) # Negated Patterns don't have a type
       
    60 
       
    61     if isinstance(pat, pytree.WildcardPattern):
       
    62         # Recurse on each node in content
       
    63         r = set()
       
    64         for p in pat.content:
       
    65             for x in p:
       
    66                 r.update(get_head_types(x))
       
    67         return r
       
    68 
       
    69     raise Exception("Oh no! I don't understand pattern %s" %(pat))
       
    70 
       
    71 def get_headnode_dict(fixer_list):
       
    72     """ Accepts a list of fixers and returns a dictionary
       
    73         of head node type --> fixer list.  """
       
    74     head_nodes = defaultdict(list)
       
    75     for fixer in fixer_list:
       
    76         if not fixer.pattern:
       
    77             head_nodes[None].append(fixer)
       
    78             continue
       
    79         for t in get_head_types(fixer.pattern):
       
    80             head_nodes[t].append(fixer)
       
    81     return head_nodes
       
    82 
       
    83 def get_fixers_from_package(pkg_name):
       
    84     """
       
    85     Return the fully qualified names for fixers in the package pkg_name.
       
    86     """
       
    87     return [pkg_name + "." + fix_name
       
    88             for fix_name in get_all_fix_names(pkg_name, False)]
       
    89 
       
    90 
       
    91 class FixerError(Exception):
       
    92     """A fixer could not be loaded."""
       
    93 
       
    94 
       
    95 class RefactoringTool(object):
       
    96 
       
    97     _default_options = {"print_function": False}
       
    98 
       
    99     CLASS_PREFIX = "Fix" # The prefix for fixer classes
       
   100     FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
       
   101 
       
   102     def __init__(self, fixer_names, options=None, explicit=None):
       
   103         """Initializer.
       
   104 
       
   105         Args:
       
   106             fixer_names: a list of fixers to import
       
   107             options: an dict with configuration.
       
   108             explicit: a list of fixers to run even if they are explicit.
       
   109         """
       
   110         self.fixers = fixer_names
       
   111         self.explicit = explicit or []
       
   112         self.options = self._default_options.copy()
       
   113         if options is not None:
       
   114             self.options.update(options)
       
   115         self.errors = []
       
   116         self.logger = logging.getLogger("RefactoringTool")
       
   117         self.fixer_log = []
       
   118         self.wrote = False
       
   119         if self.options["print_function"]:
       
   120             del pygram.python_grammar.keywords["print"]
       
   121         self.driver = driver.Driver(pygram.python_grammar,
       
   122                                     convert=pytree.convert,
       
   123                                     logger=self.logger)
       
   124         self.pre_order, self.post_order = self.get_fixers()
       
   125 
       
   126         self.pre_order_mapping = get_headnode_dict(self.pre_order)
       
   127         self.post_order_mapping = get_headnode_dict(self.post_order)
       
   128 
       
   129         self.files = []  # List of files that were or should be modified
       
   130 
       
   131     def get_fixers(self):
       
   132         """Inspects the options to load the requested patterns and handlers.
       
   133 
       
   134         Returns:
       
   135           (pre_order, post_order), where pre_order is the list of fixers that
       
   136           want a pre-order AST traversal, and post_order is the list that want
       
   137           post-order traversal.
       
   138         """
       
   139         pre_order_fixers = []
       
   140         post_order_fixers = []
       
   141         for fix_mod_path in self.fixers:
       
   142             mod = __import__(fix_mod_path, {}, {}, ["*"])
       
   143             fix_name = fix_mod_path.rsplit(".", 1)[-1]
       
   144             if fix_name.startswith(self.FILE_PREFIX):
       
   145                 fix_name = fix_name[len(self.FILE_PREFIX):]
       
   146             parts = fix_name.split("_")
       
   147             class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
       
   148             try:
       
   149                 fix_class = getattr(mod, class_name)
       
   150             except AttributeError:
       
   151                 raise FixerError("Can't find %s.%s" % (fix_name, class_name))
       
   152             fixer = fix_class(self.options, self.fixer_log)
       
   153             if fixer.explicit and self.explicit is not True and \
       
   154                     fix_mod_path not in self.explicit:
       
   155                 self.log_message("Skipping implicit fixer: %s", fix_name)
       
   156                 continue
       
   157 
       
   158             self.log_debug("Adding transformation: %s", fix_name)
       
   159             if fixer.order == "pre":
       
   160                 pre_order_fixers.append(fixer)
       
   161             elif fixer.order == "post":
       
   162                 post_order_fixers.append(fixer)
       
   163             else:
       
   164                 raise FixerError("Illegal fixer order: %r" % fixer.order)
       
   165 
       
   166         key_func = operator.attrgetter("run_order")
       
   167         pre_order_fixers.sort(key=key_func)
       
   168         post_order_fixers.sort(key=key_func)
       
   169         return (pre_order_fixers, post_order_fixers)
       
   170 
       
   171     def log_error(self, msg, *args, **kwds):
       
   172         """Called when an error occurs."""
       
   173         raise
       
   174 
       
   175     def log_message(self, msg, *args):
       
   176         """Hook to log a message."""
       
   177         if args:
       
   178             msg = msg % args
       
   179         self.logger.info(msg)
       
   180 
       
   181     def log_debug(self, msg, *args):
       
   182         if args:
       
   183             msg = msg % args
       
   184         self.logger.debug(msg)
       
   185 
       
   186     def print_output(self, lines):
       
   187         """Called with lines of output to give to the user."""
       
   188         pass
       
   189 
       
   190     def refactor(self, items, write=False, doctests_only=False):
       
   191         """Refactor a list of files and directories."""
       
   192         for dir_or_file in items:
       
   193             if os.path.isdir(dir_or_file):
       
   194                 self.refactor_dir(dir_or_file, write, doctests_only)
       
   195             else:
       
   196                 self.refactor_file(dir_or_file, write, doctests_only)
       
   197 
       
   198     def refactor_dir(self, dir_name, write=False, doctests_only=False):
       
   199         """Descends down a directory and refactor every Python file found.
       
   200 
       
   201         Python files are assumed to have a .py extension.
       
   202 
       
   203         Files and subdirectories starting with '.' are skipped.
       
   204         """
       
   205         for dirpath, dirnames, filenames in os.walk(dir_name):
       
   206             self.log_debug("Descending into %s", dirpath)
       
   207             dirnames.sort()
       
   208             filenames.sort()
       
   209             for name in filenames:
       
   210                 if not name.startswith(".") and name.endswith("py"):
       
   211                     fullname = os.path.join(dirpath, name)
       
   212                     self.refactor_file(fullname, write, doctests_only)
       
   213             # Modify dirnames in-place to remove subdirs with leading dots
       
   214             dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
       
   215 
       
   216     def refactor_file(self, filename, write=False, doctests_only=False):
       
   217         """Refactors a file."""
       
   218         try:
       
   219             f = open(filename)
       
   220         except IOError, err:
       
   221             self.log_error("Can't open %s: %s", filename, err)
       
   222             return
       
   223         try:
       
   224             input = f.read() + "\n" # Silence certain parse errors
       
   225         finally:
       
   226             f.close()
       
   227         if doctests_only:
       
   228             self.log_debug("Refactoring doctests in %s", filename)
       
   229             output = self.refactor_docstring(input, filename)
       
   230             if output != input:
       
   231                 self.processed_file(output, filename, input, write=write)
       
   232             else:
       
   233                 self.log_debug("No doctest changes in %s", filename)
       
   234         else:
       
   235             tree = self.refactor_string(input, filename)
       
   236             if tree and tree.was_changed:
       
   237                 # The [:-1] is to take off the \n we added earlier
       
   238                 self.processed_file(str(tree)[:-1], filename, write=write)
       
   239             else:
       
   240                 self.log_debug("No changes in %s", filename)
       
   241 
       
   242     def refactor_string(self, data, name):
       
   243         """Refactor a given input string.
       
   244 
       
   245         Args:
       
   246             data: a string holding the code to be refactored.
       
   247             name: a human-readable name for use in error/log messages.
       
   248 
       
   249         Returns:
       
   250             An AST corresponding to the refactored input stream; None if
       
   251             there were errors during the parse.
       
   252         """
       
   253         try:
       
   254             tree = self.driver.parse_string(data)
       
   255         except Exception, err:
       
   256             self.log_error("Can't parse %s: %s: %s",
       
   257                            name, err.__class__.__name__, err)
       
   258             return
       
   259         self.log_debug("Refactoring %s", name)
       
   260         self.refactor_tree(tree, name)
       
   261         return tree
       
   262 
       
   263     def refactor_stdin(self, doctests_only=False):
       
   264         input = sys.stdin.read()
       
   265         if doctests_only:
       
   266             self.log_debug("Refactoring doctests in stdin")
       
   267             output = self.refactor_docstring(input, "<stdin>")
       
   268             if output != input:
       
   269                 self.processed_file(output, "<stdin>", input)
       
   270             else:
       
   271                 self.log_debug("No doctest changes in stdin")
       
   272         else:
       
   273             tree = self.refactor_string(input, "<stdin>")
       
   274             if tree and tree.was_changed:
       
   275                 self.processed_file(str(tree), "<stdin>", input)
       
   276             else:
       
   277                 self.log_debug("No changes in stdin")
       
   278 
       
   279     def refactor_tree(self, tree, name):
       
   280         """Refactors a parse tree (modifying the tree in place).
       
   281 
       
   282         Args:
       
   283             tree: a pytree.Node instance representing the root of the tree
       
   284                   to be refactored.
       
   285             name: a human-readable name for this tree.
       
   286 
       
   287         Returns:
       
   288             True if the tree was modified, False otherwise.
       
   289         """
       
   290         # Two calls to chain are required because pre_order.values()
       
   291         #   will be a list of lists of fixers:
       
   292         #   [[<fixer ...>, <fixer ...>], [<fixer ...>]]
       
   293         all_fixers = chain(self.pre_order, self.post_order)
       
   294         for fixer in all_fixers:
       
   295             fixer.start_tree(tree, name)
       
   296 
       
   297         self.traverse_by(self.pre_order_mapping, tree.pre_order())
       
   298         self.traverse_by(self.post_order_mapping, tree.post_order())
       
   299 
       
   300         for fixer in all_fixers:
       
   301             fixer.finish_tree(tree, name)
       
   302         return tree.was_changed
       
   303 
       
   304     def traverse_by(self, fixers, traversal):
       
   305         """Traverse an AST, applying a set of fixers to each node.
       
   306 
       
   307         This is a helper method for refactor_tree().
       
   308 
       
   309         Args:
       
   310             fixers: a list of fixer instances.
       
   311             traversal: a generator that yields AST nodes.
       
   312 
       
   313         Returns:
       
   314             None
       
   315         """
       
   316         if not fixers:
       
   317             return
       
   318         for node in traversal:
       
   319             for fixer in fixers[node.type] + fixers[None]:
       
   320                 results = fixer.match(node)
       
   321                 if results:
       
   322                     new = fixer.transform(node, results)
       
   323                     if new is not None and (new != node or
       
   324                                             str(new) != str(node)):
       
   325                         node.replace(new)
       
   326                         node = new
       
   327 
       
   328     def processed_file(self, new_text, filename, old_text=None, write=False):
       
   329         """
       
   330         Called when a file has been refactored, and there are changes.
       
   331         """
       
   332         self.files.append(filename)
       
   333         if old_text is None:
       
   334             try:
       
   335                 f = open(filename, "r")
       
   336             except IOError, err:
       
   337                 self.log_error("Can't read %s: %s", filename, err)
       
   338                 return
       
   339             try:
       
   340                 old_text = f.read()
       
   341             finally:
       
   342                 f.close()
       
   343         if old_text == new_text:
       
   344             self.log_debug("No changes to %s", filename)
       
   345             return
       
   346         self.print_output(diff_texts(old_text, new_text, filename))
       
   347         if write:
       
   348             self.write_file(new_text, filename, old_text)
       
   349         else:
       
   350             self.log_debug("Not writing changes to %s", filename)
       
   351 
       
   352     def write_file(self, new_text, filename, old_text):
       
   353         """Writes a string to a file.
       
   354 
       
   355         It first shows a unified diff between the old text and the new text, and
       
   356         then rewrites the file; the latter is only done if the write option is
       
   357         set.
       
   358         """
       
   359         try:
       
   360             f = open(filename, "w")
       
   361         except os.error, err:
       
   362             self.log_error("Can't create %s: %s", filename, err)
       
   363             return
       
   364         try:
       
   365             f.write(new_text)
       
   366         except os.error, err:
       
   367             self.log_error("Can't write %s: %s", filename, err)
       
   368         finally:
       
   369             f.close()
       
   370         self.log_debug("Wrote changes to %s", filename)
       
   371         self.wrote = True
       
   372 
       
   373     PS1 = ">>> "
       
   374     PS2 = "... "
       
   375 
       
   376     def refactor_docstring(self, input, filename):
       
   377         """Refactors a docstring, looking for doctests.
       
   378 
       
   379         This returns a modified version of the input string.  It looks
       
   380         for doctests, which start with a ">>>" prompt, and may be
       
   381         continued with "..." prompts, as long as the "..." is indented
       
   382         the same as the ">>>".
       
   383 
       
   384         (Unfortunately we can't use the doctest module's parser,
       
   385         since, like most parsers, it is not geared towards preserving
       
   386         the original source.)
       
   387         """
       
   388         result = []
       
   389         block = None
       
   390         block_lineno = None
       
   391         indent = None
       
   392         lineno = 0
       
   393         for line in input.splitlines(True):
       
   394             lineno += 1
       
   395             if line.lstrip().startswith(self.PS1):
       
   396                 if block is not None:
       
   397                     result.extend(self.refactor_doctest(block, block_lineno,
       
   398                                                         indent, filename))
       
   399                 block_lineno = lineno
       
   400                 block = [line]
       
   401                 i = line.find(self.PS1)
       
   402                 indent = line[:i]
       
   403             elif (indent is not None and
       
   404                   (line.startswith(indent + self.PS2) or
       
   405                    line == indent + self.PS2.rstrip() + "\n")):
       
   406                 block.append(line)
       
   407             else:
       
   408                 if block is not None:
       
   409                     result.extend(self.refactor_doctest(block, block_lineno,
       
   410                                                         indent, filename))
       
   411                 block = None
       
   412                 indent = None
       
   413                 result.append(line)
       
   414         if block is not None:
       
   415             result.extend(self.refactor_doctest(block, block_lineno,
       
   416                                                 indent, filename))
       
   417         return "".join(result)
       
   418 
       
   419     def refactor_doctest(self, block, lineno, indent, filename):
       
   420         """Refactors one doctest.
       
   421 
       
   422         A doctest is given as a block of lines, the first of which starts
       
   423         with ">>>" (possibly indented), while the remaining lines start
       
   424         with "..." (identically indented).
       
   425 
       
   426         """
       
   427         try:
       
   428             tree = self.parse_block(block, lineno, indent)
       
   429         except Exception, err:
       
   430             if self.log.isEnabledFor(logging.DEBUG):
       
   431                 for line in block:
       
   432                     self.log_debug("Source: %s", line.rstrip("\n"))
       
   433             self.log_error("Can't parse docstring in %s line %s: %s: %s",
       
   434                            filename, lineno, err.__class__.__name__, err)
       
   435             return block
       
   436         if self.refactor_tree(tree, filename):
       
   437             new = str(tree).splitlines(True)
       
   438             # Undo the adjustment of the line numbers in wrap_toks() below.
       
   439             clipped, new = new[:lineno-1], new[lineno-1:]
       
   440             assert clipped == ["\n"] * (lineno-1), clipped
       
   441             if not new[-1].endswith("\n"):
       
   442                 new[-1] += "\n"
       
   443             block = [indent + self.PS1 + new.pop(0)]
       
   444             if new:
       
   445                 block += [indent + self.PS2 + line for line in new]
       
   446         return block
       
   447 
       
   448     def summarize(self):
       
   449         if self.wrote:
       
   450             were = "were"
       
   451         else:
       
   452             were = "need to be"
       
   453         if not self.files:
       
   454             self.log_message("No files %s modified.", were)
       
   455         else:
       
   456             self.log_message("Files that %s modified:", were)
       
   457             for file in self.files:
       
   458                 self.log_message(file)
       
   459         if self.fixer_log:
       
   460             self.log_message("Warnings/messages while refactoring:")
       
   461             for message in self.fixer_log:
       
   462                 self.log_message(message)
       
   463         if self.errors:
       
   464             if len(self.errors) == 1:
       
   465                 self.log_message("There was 1 error:")
       
   466             else:
       
   467                 self.log_message("There were %d errors:", len(self.errors))
       
   468             for msg, args, kwds in self.errors:
       
   469                 self.log_message(msg, *args, **kwds)
       
   470 
       
   471     def parse_block(self, block, lineno, indent):
       
   472         """Parses a block into a tree.
       
   473 
       
   474         This is necessary to get correct line number / offset information
       
   475         in the parser diagnostics and embedded into the parse tree.
       
   476         """
       
   477         return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
       
   478 
       
   479     def wrap_toks(self, block, lineno, indent):
       
   480         """Wraps a tokenize stream to systematically modify start/end."""
       
   481         tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
       
   482         for type, value, (line0, col0), (line1, col1), line_text in tokens:
       
   483             line0 += lineno - 1
       
   484             line1 += lineno - 1
       
   485             # Don't bother updating the columns; this is too complicated
       
   486             # since line_text would also have to be updated and it would
       
   487             # still break for tokens spanning lines.  Let the user guess
       
   488             # that the column numbers for doctests are relative to the
       
   489             # end of the prompt string (PS1 or PS2).
       
   490             yield type, value, (line0, col0), (line1, col1), line_text
       
   491 
       
   492 
       
   493     def gen_lines(self, block, indent):
       
   494         """Generates lines as expected by tokenize from a list of lines.
       
   495 
       
   496         This strips the first len(indent + self.PS1) characters off each line.
       
   497         """
       
   498         prefix1 = indent + self.PS1
       
   499         prefix2 = indent + self.PS2
       
   500         prefix = prefix1
       
   501         for line in block:
       
   502             if line.startswith(prefix):
       
   503                 yield line[len(prefix):]
       
   504             elif line == prefix.rstrip() + "\n":
       
   505                 yield "\n"
       
   506             else:
       
   507                 raise AssertionError("line=%r, prefix=%r" % (line, prefix))
       
   508             prefix = prefix2
       
   509         while True:
       
   510             yield ""
       
   511 
       
   512 
       
   513 def diff_texts(a, b, filename):
       
   514     """Return a unified diff of two strings."""
       
   515     a = a.splitlines()
       
   516     b = b.splitlines()
       
   517     return difflib.unified_diff(a, b, filename, filename,
       
   518                                 "(original)", "(refactored)",
       
   519                                 lineterm="")