symbian-qemu-0.9.1-12/python-2.6.1/Parser/asdl_c.py
changeset 1 2fb8b9db1c86
equal deleted inserted replaced
0:ffa851df0825 1:2fb8b9db1c86
       
     1 #! /usr/bin/env python
       
     2 """Generate C code from an ASDL description."""
       
     3 
       
     4 # TO DO
       
     5 # handle fields that have a type but no name
       
     6 
       
     7 import os, sys
       
     8 
       
     9 import asdl
       
    10 
       
    11 TABSIZE = 8
       
    12 MAX_COL = 80
       
    13 
       
    14 def get_c_type(name):
       
    15     """Return a string for the C name of the type.
       
    16 
       
    17     This function special cases the default types provided by asdl:
       
    18     identifier, string, int, bool.
       
    19     """
       
    20     # XXX ack!  need to figure out where Id is useful and where string
       
    21     if isinstance(name, asdl.Id):
       
    22         name = name.value
       
    23     if name in asdl.builtin_types:
       
    24         return name
       
    25     else:
       
    26         return "%s_ty" % name
       
    27 
       
    28 def reflow_lines(s, depth):
       
    29     """Reflow the line s indented depth tabs.
       
    30 
       
    31     Return a sequence of lines where no line extends beyond MAX_COL
       
    32     when properly indented.  The first line is properly indented based
       
    33     exclusively on depth * TABSIZE.  All following lines -- these are
       
    34     the reflowed lines generated by this function -- start at the same
       
    35     column as the first character beyond the opening { in the first
       
    36     line.
       
    37     """
       
    38     size = MAX_COL - depth * TABSIZE
       
    39     if len(s) < size:
       
    40         return [s]
       
    41 
       
    42     lines = []
       
    43     cur = s
       
    44     padding = ""
       
    45     while len(cur) > size:
       
    46         i = cur.rfind(' ', 0, size)
       
    47         # XXX this should be fixed for real
       
    48         if i == -1 and 'GeneratorExp' in cur:
       
    49             i = size + 3
       
    50         assert i != -1, "Impossible line %d to reflow: %r" % (size, s)
       
    51         lines.append(padding + cur[:i])
       
    52         if len(lines) == 1:
       
    53             # find new size based on brace
       
    54             j = cur.find('{', 0, i)
       
    55             if j >= 0:
       
    56                 j += 2 # account for the brace and the space after it
       
    57                 size -= j
       
    58                 padding = " " * j
       
    59             else:
       
    60                 j = cur.find('(', 0, i)
       
    61                 if j >= 0:
       
    62                     j += 1 # account for the paren (no space after it)
       
    63                     size -= j
       
    64                     padding = " " * j
       
    65         cur = cur[i+1:]
       
    66     else:
       
    67         lines.append(padding + cur)
       
    68     return lines
       
    69 
       
    70 def is_simple(sum):
       
    71     """Return True if a sum is a simple.
       
    72 
       
    73     A sum is simple if its types have no fields, e.g.
       
    74     unaryop = Invert | Not | UAdd | USub
       
    75     """
       
    76     for t in sum.types:
       
    77         if t.fields:
       
    78             return False
       
    79     return True
       
    80 
       
    81 
       
    82 class EmitVisitor(asdl.VisitorBase):
       
    83     """Visit that emits lines"""
       
    84 
       
    85     def __init__(self, file):
       
    86         self.file = file
       
    87         super(EmitVisitor, self).__init__()
       
    88 
       
    89     def emit(self, s, depth, reflow=1):
       
    90         # XXX reflow long lines?
       
    91         if reflow:
       
    92             lines = reflow_lines(s, depth)
       
    93         else:
       
    94             lines = [s]
       
    95         for line in lines:
       
    96             line = (" " * TABSIZE * depth) + line + "\n"
       
    97             self.file.write(line)
       
    98 
       
    99 
       
   100 class TypeDefVisitor(EmitVisitor):
       
   101     def visitModule(self, mod):
       
   102         for dfn in mod.dfns:
       
   103             self.visit(dfn)
       
   104 
       
   105     def visitType(self, type, depth=0):
       
   106         self.visit(type.value, type.name, depth)
       
   107 
       
   108     def visitSum(self, sum, name, depth):
       
   109         if is_simple(sum):
       
   110             self.simple_sum(sum, name, depth)
       
   111         else:
       
   112             self.sum_with_constructors(sum, name, depth)
       
   113 
       
   114     def simple_sum(self, sum, name, depth):
       
   115         enum = []
       
   116         for i in range(len(sum.types)):
       
   117             type = sum.types[i]
       
   118             enum.append("%s=%d" % (type.name, i + 1))
       
   119         enums = ", ".join(enum)
       
   120         ctype = get_c_type(name)
       
   121         s = "typedef enum _%s { %s } %s;" % (name, enums, ctype)
       
   122         self.emit(s, depth)
       
   123         self.emit("", depth)
       
   124 
       
   125     def sum_with_constructors(self, sum, name, depth):
       
   126         ctype = get_c_type(name)
       
   127         s = "typedef struct _%(name)s *%(ctype)s;" % locals()
       
   128         self.emit(s, depth)
       
   129         self.emit("", depth)
       
   130 
       
   131     def visitProduct(self, product, name, depth):
       
   132         ctype = get_c_type(name)
       
   133         s = "typedef struct _%(name)s *%(ctype)s;" % locals()
       
   134         self.emit(s, depth)
       
   135         self.emit("", depth)
       
   136 
       
   137 
       
   138 class StructVisitor(EmitVisitor):
       
   139     """Visitor to generate typdefs for AST."""
       
   140 
       
   141     def visitModule(self, mod):
       
   142         for dfn in mod.dfns:
       
   143             self.visit(dfn)
       
   144 
       
   145     def visitType(self, type, depth=0):
       
   146         self.visit(type.value, type.name, depth)
       
   147 
       
   148     def visitSum(self, sum, name, depth):
       
   149         if not is_simple(sum):
       
   150             self.sum_with_constructors(sum, name, depth)
       
   151 
       
   152     def sum_with_constructors(self, sum, name, depth):
       
   153         def emit(s, depth=depth):
       
   154             self.emit(s % sys._getframe(1).f_locals, depth)
       
   155         enum = []
       
   156         for i in range(len(sum.types)):
       
   157             type = sum.types[i]
       
   158             enum.append("%s_kind=%d" % (type.name, i + 1))
       
   159 
       
   160         emit("enum _%(name)s_kind {" + ", ".join(enum) + "};")
       
   161 
       
   162         emit("struct _%(name)s {")
       
   163         emit("enum _%(name)s_kind kind;", depth + 1)
       
   164         emit("union {", depth + 1)
       
   165         for t in sum.types:
       
   166             self.visit(t, depth + 2)
       
   167         emit("} v;", depth + 1)
       
   168         for field in sum.attributes:
       
   169             # rudimentary attribute handling
       
   170             type = str(field.type)
       
   171             assert type in asdl.builtin_types, type
       
   172             emit("%s %s;" % (type, field.name), depth + 1);
       
   173         emit("};")
       
   174         emit("")
       
   175 
       
   176     def visitConstructor(self, cons, depth):
       
   177         if cons.fields:
       
   178             self.emit("struct {", depth)
       
   179             for f in cons.fields:
       
   180                 self.visit(f, depth + 1)
       
   181             self.emit("} %s;" % cons.name, depth)
       
   182             self.emit("", depth)
       
   183         else:
       
   184             # XXX not sure what I want here, nothing is probably fine
       
   185             pass
       
   186 
       
   187     def visitField(self, field, depth):
       
   188         # XXX need to lookup field.type, because it might be something
       
   189         # like a builtin...
       
   190         ctype = get_c_type(field.type)
       
   191         name = field.name
       
   192         if field.seq:
       
   193             if field.type.value in ('cmpop',):
       
   194                 self.emit("asdl_int_seq *%(name)s;" % locals(), depth)
       
   195             else:
       
   196                 self.emit("asdl_seq *%(name)s;" % locals(), depth)
       
   197         else:
       
   198             self.emit("%(ctype)s %(name)s;" % locals(), depth)
       
   199 
       
   200     def visitProduct(self, product, name, depth):
       
   201         self.emit("struct _%(name)s {" % locals(), depth)
       
   202         for f in product.fields:
       
   203             self.visit(f, depth + 1)
       
   204         self.emit("};", depth)
       
   205         self.emit("", depth)
       
   206 
       
   207 
       
   208 class PrototypeVisitor(EmitVisitor):
       
   209     """Generate function prototypes for the .h file"""
       
   210 
       
   211     def visitModule(self, mod):
       
   212         for dfn in mod.dfns:
       
   213             self.visit(dfn)
       
   214 
       
   215     def visitType(self, type):
       
   216         self.visit(type.value, type.name)
       
   217 
       
   218     def visitSum(self, sum, name):
       
   219         if is_simple(sum):
       
   220             pass # XXX
       
   221         else:
       
   222             for t in sum.types:
       
   223                 self.visit(t, name, sum.attributes)
       
   224 
       
   225     def get_args(self, fields):
       
   226         """Return list of C argument into, one for each field.
       
   227 
       
   228         Argument info is 3-tuple of a C type, variable name, and flag
       
   229         that is true if type can be NULL.
       
   230         """
       
   231         args = []
       
   232         unnamed = {}
       
   233         for f in fields:
       
   234             if f.name is None:
       
   235                 name = f.type
       
   236                 c = unnamed[name] = unnamed.get(name, 0) + 1
       
   237                 if c > 1:
       
   238                     name = "name%d" % (c - 1)
       
   239             else:
       
   240                 name = f.name
       
   241             # XXX should extend get_c_type() to handle this
       
   242             if f.seq:
       
   243                 if f.type.value in ('cmpop',):
       
   244                     ctype = "asdl_int_seq *"
       
   245                 else:
       
   246                     ctype = "asdl_seq *"
       
   247             else:
       
   248                 ctype = get_c_type(f.type)
       
   249             args.append((ctype, name, f.opt or f.seq))
       
   250         return args
       
   251 
       
   252     def visitConstructor(self, cons, type, attrs):
       
   253         args = self.get_args(cons.fields)
       
   254         attrs = self.get_args(attrs)
       
   255         ctype = get_c_type(type)
       
   256         self.emit_function(cons.name, ctype, args, attrs)
       
   257 
       
   258     def emit_function(self, name, ctype, args, attrs, union=1):
       
   259         args = args + attrs
       
   260         if args:
       
   261             argstr = ", ".join(["%s %s" % (atype, aname)
       
   262                                 for atype, aname, opt in args])
       
   263             argstr += ", PyArena *arena"
       
   264         else:
       
   265             argstr = "PyArena *arena"
       
   266         margs = "a0"
       
   267         for i in range(1, len(args)+1):
       
   268             margs += ", a%d" % i
       
   269         self.emit("#define %s(%s) _Py_%s(%s)" % (name, margs, name, margs), 0,
       
   270                 reflow = 0)
       
   271         self.emit("%s _Py_%s(%s);" % (ctype, name, argstr), 0)
       
   272 
       
   273     def visitProduct(self, prod, name):
       
   274         self.emit_function(name, get_c_type(name),
       
   275                            self.get_args(prod.fields), [], union=0)
       
   276 
       
   277 
       
   278 class FunctionVisitor(PrototypeVisitor):
       
   279     """Visitor to generate constructor functions for AST."""
       
   280 
       
   281     def emit_function(self, name, ctype, args, attrs, union=1):
       
   282         def emit(s, depth=0, reflow=1):
       
   283             self.emit(s, depth, reflow)
       
   284         argstr = ", ".join(["%s %s" % (atype, aname)
       
   285                             for atype, aname, opt in args + attrs])
       
   286         if argstr:
       
   287             argstr += ", PyArena *arena"
       
   288         else:
       
   289             argstr = "PyArena *arena"
       
   290         self.emit("%s" % ctype, 0)
       
   291         emit("%s(%s)" % (name, argstr))
       
   292         emit("{")
       
   293         emit("%s p;" % ctype, 1)
       
   294         for argtype, argname, opt in args:
       
   295             # XXX hack alert: false is allowed for a bool
       
   296             if not opt and not (argtype == "bool" or argtype == "int"):
       
   297                 emit("if (!%s) {" % argname, 1)
       
   298                 emit("PyErr_SetString(PyExc_ValueError,", 2)
       
   299                 msg = "field %s is required for %s" % (argname, name)
       
   300                 emit('                "%s");' % msg,
       
   301                      2, reflow=0)
       
   302                 emit('return NULL;', 2)
       
   303                 emit('}', 1)
       
   304 
       
   305         emit("p = (%s)PyArena_Malloc(arena, sizeof(*p));" % ctype, 1);
       
   306         emit("if (!p)", 1)
       
   307         emit("return NULL;", 2)
       
   308         if union:
       
   309             self.emit_body_union(name, args, attrs)
       
   310         else:
       
   311             self.emit_body_struct(name, args, attrs)
       
   312         emit("return p;", 1)
       
   313         emit("}")
       
   314         emit("")
       
   315 
       
   316     def emit_body_union(self, name, args, attrs):
       
   317         def emit(s, depth=0, reflow=1):
       
   318             self.emit(s, depth, reflow)
       
   319         emit("p->kind = %s_kind;" % name, 1)
       
   320         for argtype, argname, opt in args:
       
   321             emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
       
   322         for argtype, argname, opt in attrs:
       
   323             emit("p->%s = %s;" % (argname, argname), 1)
       
   324 
       
   325     def emit_body_struct(self, name, args, attrs):
       
   326         def emit(s, depth=0, reflow=1):
       
   327             self.emit(s, depth, reflow)
       
   328         for argtype, argname, opt in args:
       
   329             emit("p->%s = %s;" % (argname, argname), 1)
       
   330         assert not attrs
       
   331 
       
   332 
       
   333 class PickleVisitor(EmitVisitor):
       
   334 
       
   335     def visitModule(self, mod):
       
   336         for dfn in mod.dfns:
       
   337             self.visit(dfn)
       
   338 
       
   339     def visitType(self, type):
       
   340         self.visit(type.value, type.name)
       
   341 
       
   342     def visitSum(self, sum, name):
       
   343         pass
       
   344 
       
   345     def visitProduct(self, sum, name):
       
   346         pass
       
   347 
       
   348     def visitConstructor(self, cons, name):
       
   349         pass
       
   350 
       
   351     def visitField(self, sum):
       
   352         pass
       
   353 
       
   354 
       
   355 class Obj2ModPrototypeVisitor(PickleVisitor):
       
   356     def visitProduct(self, prod, name):
       
   357         code = "static int obj2ast_%s(PyObject* obj, %s* out, PyArena* arena);"
       
   358         self.emit(code % (name, get_c_type(name)), 0)
       
   359 
       
   360     visitSum = visitProduct
       
   361 
       
   362 
       
   363 class Obj2ModVisitor(PickleVisitor):
       
   364     def funcHeader(self, name):
       
   365         ctype = get_c_type(name)
       
   366         self.emit("int", 0)
       
   367         self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
       
   368         self.emit("{", 0)
       
   369         self.emit("PyObject* tmp = NULL;", 1)
       
   370         self.emit("", 0)
       
   371 
       
   372     def sumTrailer(self, name):
       
   373         self.emit("", 0)
       
   374         self.emit("tmp = PyObject_Repr(obj);", 1)
       
   375         # there's really nothing more we can do if this fails ...
       
   376         self.emit("if (tmp == NULL) goto failed;", 1)
       
   377         error = "expected some sort of %s, but got %%.400s" % name
       
   378         format = "PyErr_Format(PyExc_TypeError, \"%s\", PyString_AS_STRING(tmp));"
       
   379         self.emit(format % error, 1, reflow=False)
       
   380         self.emit("failed:", 0)
       
   381         self.emit("Py_XDECREF(tmp);", 1)
       
   382         self.emit("return 1;", 1)
       
   383         self.emit("}", 0)
       
   384         self.emit("", 0)
       
   385 
       
   386     def simpleSum(self, sum, name):
       
   387         self.funcHeader(name)
       
   388         for t in sum.types:
       
   389             self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
       
   390             self.emit("*out = %s;" % t.name, 2)
       
   391             self.emit("return 0;", 2)
       
   392             self.emit("}", 1)
       
   393         self.sumTrailer(name)
       
   394 
       
   395     def buildArgs(self, fields):
       
   396         return ", ".join(fields + ["arena"])
       
   397 
       
   398     def complexSum(self, sum, name):
       
   399         self.funcHeader(name)
       
   400         for a in sum.attributes:
       
   401             self.visitAttributeDeclaration(a, name, sum=sum)
       
   402         self.emit("", 0)
       
   403         # XXX: should we only do this for 'expr'?
       
   404         self.emit("if (obj == Py_None) {", 1)
       
   405         self.emit("*out = NULL;", 2)
       
   406         self.emit("return 0;", 2)
       
   407         self.emit("}", 1)
       
   408         for a in sum.attributes:
       
   409             self.visitField(a, name, sum=sum, depth=1)
       
   410         for t in sum.types:
       
   411             self.emit("if (PyObject_IsInstance(obj, (PyObject*)%s_type)) {" % t.name, 1)
       
   412             for f in t.fields:
       
   413                 self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
       
   414             self.emit("", 0)
       
   415             for f in t.fields:
       
   416                 self.visitField(f, t.name, sum=sum, depth=2)
       
   417             args = [f.name.value for f in t.fields] + [a.name.value for a in sum.attributes]
       
   418             self.emit("*out = %s(%s);" % (t.name, self.buildArgs(args)), 2)
       
   419             self.emit("if (*out == NULL) goto failed;", 2)
       
   420             self.emit("return 0;", 2)
       
   421             self.emit("}", 1)
       
   422         self.sumTrailer(name)
       
   423 
       
   424     def visitAttributeDeclaration(self, a, name, sum=sum):
       
   425         ctype = get_c_type(a.type)
       
   426         self.emit("%s %s;" % (ctype, a.name), 1)
       
   427 
       
   428     def visitSum(self, sum, name):
       
   429         if is_simple(sum):
       
   430             self.simpleSum(sum, name)
       
   431         else:
       
   432             self.complexSum(sum, name)
       
   433 
       
   434     def visitProduct(self, prod, name):
       
   435         ctype = get_c_type(name)
       
   436         self.emit("int", 0)
       
   437         self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
       
   438         self.emit("{", 0)
       
   439         self.emit("PyObject* tmp = NULL;", 1)
       
   440         for f in prod.fields:
       
   441             self.visitFieldDeclaration(f, name, prod=prod, depth=1)
       
   442         self.emit("", 0)
       
   443         for f in prod.fields:
       
   444             self.visitField(f, name, prod=prod, depth=1)
       
   445         args = [f.name.value for f in prod.fields]
       
   446         self.emit("*out = %s(%s);" % (name, self.buildArgs(args)), 1)
       
   447         self.emit("return 0;", 1)
       
   448         self.emit("failed:", 0)
       
   449         self.emit("Py_XDECREF(tmp);", 1)
       
   450         self.emit("return 1;", 1)
       
   451         self.emit("}", 0)
       
   452         self.emit("", 0)
       
   453 
       
   454     def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
       
   455         ctype = get_c_type(field.type)
       
   456         if field.seq:
       
   457             if self.isSimpleType(field):
       
   458                 self.emit("asdl_int_seq* %s;" % field.name, depth)
       
   459             else:
       
   460                 self.emit("asdl_seq* %s;" % field.name, depth)
       
   461         else:
       
   462             ctype = get_c_type(field.type)
       
   463             self.emit("%s %s;" % (ctype, field.name), depth)
       
   464 
       
   465     def isSimpleSum(self, field):
       
   466         # XXX can the members of this list be determined automatically?
       
   467         return field.type.value in ('expr_context', 'boolop', 'operator',
       
   468                                     'unaryop', 'cmpop')
       
   469 
       
   470     def isNumeric(self, field):
       
   471         return get_c_type(field.type) in ("int", "bool")
       
   472 
       
   473     def isSimpleType(self, field):
       
   474         return self.isSimpleSum(field) or self.isNumeric(field)
       
   475 
       
   476     def visitField(self, field, name, sum=None, prod=None, depth=0):
       
   477         ctype = get_c_type(field.type)
       
   478         self.emit("if (PyObject_HasAttrString(obj, \"%s\")) {" % field.name, depth)
       
   479         self.emit("int res;", depth+1)
       
   480         if field.seq:
       
   481             self.emit("Py_ssize_t len;", depth+1)
       
   482             self.emit("Py_ssize_t i;", depth+1)
       
   483         self.emit("tmp = PyObject_GetAttrString(obj, \"%s\");" % field.name, depth+1)
       
   484         self.emit("if (tmp == NULL) goto failed;", depth+1)
       
   485         if field.seq:
       
   486             self.emit("if (!PyList_Check(tmp)) {", depth+1)
       
   487             self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
       
   488                       "be a list, not a %%.200s\", tmp->ob_type->tp_name);" %
       
   489                       (name, field.name),
       
   490                       depth+2, reflow=False)
       
   491             self.emit("goto failed;", depth+2)
       
   492             self.emit("}", depth+1)
       
   493             self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
       
   494             if self.isSimpleType(field):
       
   495                 self.emit("%s = asdl_int_seq_new(len, arena);" % field.name, depth+1)
       
   496             else:
       
   497                 self.emit("%s = asdl_seq_new(len, arena);" % field.name, depth+1)
       
   498             self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
       
   499             self.emit("for (i = 0; i < len; i++) {", depth+1)
       
   500             self.emit("%s value;" % ctype, depth+2)
       
   501             self.emit("res = obj2ast_%s(PyList_GET_ITEM(tmp, i), &value, arena);" %
       
   502                       field.type, depth+2, reflow=False)
       
   503             self.emit("if (res != 0) goto failed;", depth+2)
       
   504             self.emit("asdl_seq_SET(%s, i, value);" % field.name, depth+2)
       
   505             self.emit("}", depth+1)
       
   506         else:
       
   507             self.emit("res = obj2ast_%s(tmp, &%s, arena);" %
       
   508                       (field.type, field.name), depth+1)
       
   509             self.emit("if (res != 0) goto failed;", depth+1)
       
   510 
       
   511         self.emit("Py_XDECREF(tmp);", depth+1)
       
   512         self.emit("tmp = NULL;", depth+1)
       
   513         self.emit("} else {", depth)
       
   514         if not field.opt:
       
   515             message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
       
   516             format = "PyErr_SetString(PyExc_TypeError, \"%s\");"
       
   517             self.emit(format % message, depth+1, reflow=False)
       
   518             self.emit("return 1;", depth+1)
       
   519         else:
       
   520             if self.isNumeric(field):
       
   521                 self.emit("%s = 0;" % field.name, depth+1)
       
   522             elif not self.isSimpleType(field):
       
   523                 self.emit("%s = NULL;" % field.name, depth+1)
       
   524             else:
       
   525                 raise TypeError("could not determine the default value for %s" % field.name)
       
   526         self.emit("}", depth)
       
   527 
       
   528 
       
   529 class MarshalPrototypeVisitor(PickleVisitor):
       
   530 
       
   531     def prototype(self, sum, name):
       
   532         ctype = get_c_type(name)
       
   533         self.emit("static int marshal_write_%s(PyObject **, int *, %s);"
       
   534                   % (name, ctype), 0)
       
   535 
       
   536     visitProduct = visitSum = prototype
       
   537 
       
   538 
       
   539 class PyTypesDeclareVisitor(PickleVisitor):
       
   540 
       
   541     def visitProduct(self, prod, name):
       
   542         self.emit("static PyTypeObject *%s_type;" % name, 0)
       
   543         self.emit("static PyObject* ast2obj_%s(void*);" % name, 0)
       
   544         if prod.fields:
       
   545             self.emit("static char *%s_fields[]={" % name,0)
       
   546             for f in prod.fields:
       
   547                 self.emit('"%s",' % f.name, 1)
       
   548             self.emit("};", 0)
       
   549 
       
   550     def visitSum(self, sum, name):
       
   551         self.emit("static PyTypeObject *%s_type;" % name, 0)
       
   552         if sum.attributes:
       
   553             self.emit("static char *%s_attributes[] = {" % name, 0)
       
   554             for a in sum.attributes:
       
   555                 self.emit('"%s",' % a.name, 1)
       
   556             self.emit("};", 0)
       
   557         ptype = "void*"
       
   558         if is_simple(sum):
       
   559             ptype = get_c_type(name)
       
   560             tnames = []
       
   561             for t in sum.types:
       
   562                 tnames.append(str(t.name)+"_singleton")
       
   563             tnames = ", *".join(tnames)
       
   564             self.emit("static PyObject *%s;" % tnames, 0)
       
   565         self.emit("static PyObject* ast2obj_%s(%s);" % (name, ptype), 0)
       
   566         for t in sum.types:
       
   567             self.visitConstructor(t, name)
       
   568 
       
   569     def visitConstructor(self, cons, name):
       
   570         self.emit("static PyTypeObject *%s_type;" % cons.name, 0)
       
   571         if cons.fields:
       
   572             self.emit("static char *%s_fields[]={" % cons.name, 0)
       
   573             for t in cons.fields:
       
   574                 self.emit('"%s",' % t.name, 1)
       
   575             self.emit("};",0)
       
   576 
       
   577 class PyTypesVisitor(PickleVisitor):
       
   578 
       
   579     def visitModule(self, mod):
       
   580         self.emit("""
       
   581 static int
       
   582 ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
       
   583 {
       
   584     Py_ssize_t i, numfields = 0;
       
   585     int res = -1;
       
   586     PyObject *key, *value, *fields;
       
   587     fields = PyObject_GetAttrString((PyObject*)Py_TYPE(self), "_fields");
       
   588     if (!fields)
       
   589         PyErr_Clear();
       
   590     if (fields) {
       
   591         numfields = PySequence_Size(fields);
       
   592         if (numfields == -1)
       
   593             goto cleanup;
       
   594     }
       
   595     res = 0; /* if no error occurs, this stays 0 to the end */
       
   596     if (PyTuple_GET_SIZE(args) > 0) {
       
   597         if (numfields != PyTuple_GET_SIZE(args)) {
       
   598             PyErr_Format(PyExc_TypeError, "%.400s constructor takes %s"
       
   599                          "%zd positional argument%s",
       
   600                          Py_TYPE(self)->tp_name,
       
   601                          numfields == 0 ? "" : "either 0 or ",
       
   602                          numfields, numfields == 1 ? "" : "s");
       
   603             res = -1;
       
   604             goto cleanup;
       
   605         }
       
   606         for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
       
   607             /* cannot be reached when fields is NULL */
       
   608             PyObject *name = PySequence_GetItem(fields, i);
       
   609             if (!name) {
       
   610                 res = -1;
       
   611                 goto cleanup;
       
   612             }
       
   613             res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
       
   614             Py_DECREF(name);
       
   615             if (res < 0)
       
   616                 goto cleanup;
       
   617         }
       
   618     }
       
   619     if (kw) {
       
   620         i = 0;  /* needed by PyDict_Next */
       
   621         while (PyDict_Next(kw, &i, &key, &value)) {
       
   622             res = PyObject_SetAttr(self, key, value);
       
   623             if (res < 0)
       
   624                 goto cleanup;
       
   625         }
       
   626     }
       
   627   cleanup:
       
   628     Py_XDECREF(fields);
       
   629     return res;
       
   630 }
       
   631 
       
   632 /* Pickling support */
       
   633 static PyObject *
       
   634 ast_type_reduce(PyObject *self, PyObject *unused)
       
   635 {
       
   636     PyObject *res;
       
   637     PyObject *dict = PyObject_GetAttrString(self, "__dict__");
       
   638     if (dict == NULL) {
       
   639         if (PyErr_ExceptionMatches(PyExc_AttributeError))
       
   640             PyErr_Clear();
       
   641         else
       
   642             return NULL;
       
   643     }
       
   644     if (dict) {
       
   645         res = Py_BuildValue("O()O", Py_TYPE(self), dict);
       
   646         Py_DECREF(dict);
       
   647         return res;
       
   648     }
       
   649     return Py_BuildValue("O()", Py_TYPE(self));
       
   650 }
       
   651 
       
   652 static PyMethodDef ast_type_methods[] = {
       
   653     {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
       
   654     {NULL}
       
   655 };
       
   656 
       
   657 static PyTypeObject AST_type = {
       
   658     PyVarObject_HEAD_INIT(&PyType_Type, 0)
       
   659     "_ast.AST",
       
   660     sizeof(PyObject),
       
   661     0,
       
   662     0,                       /* tp_dealloc */
       
   663     0,                       /* tp_print */
       
   664     0,                       /* tp_getattr */
       
   665     0,                       /* tp_setattr */
       
   666     0,                       /* tp_compare */
       
   667     0,                       /* tp_repr */
       
   668     0,                       /* tp_as_number */
       
   669     0,                       /* tp_as_sequence */
       
   670     0,                       /* tp_as_mapping */
       
   671     0,                       /* tp_hash */
       
   672     0,                       /* tp_call */
       
   673     0,                       /* tp_str */
       
   674     PyObject_GenericGetAttr, /* tp_getattro */
       
   675     PyObject_GenericSetAttr, /* tp_setattro */
       
   676     0,                       /* tp_as_buffer */
       
   677     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
       
   678     0,                       /* tp_doc */
       
   679     0,                       /* tp_traverse */
       
   680     0,                       /* tp_clear */
       
   681     0,                       /* tp_richcompare */
       
   682     0,                       /* tp_weaklistoffset */
       
   683     0,                       /* tp_iter */
       
   684     0,                       /* tp_iternext */
       
   685     ast_type_methods,        /* tp_methods */
       
   686     0,                       /* tp_members */
       
   687     0,                       /* tp_getset */
       
   688     0,                       /* tp_base */
       
   689     0,                       /* tp_dict */
       
   690     0,                       /* tp_descr_get */
       
   691     0,                       /* tp_descr_set */
       
   692     0,                       /* tp_dictoffset */
       
   693     (initproc)ast_type_init, /* tp_init */
       
   694     PyType_GenericAlloc,     /* tp_alloc */
       
   695     PyType_GenericNew,       /* tp_new */
       
   696     PyObject_Del,            /* tp_free */
       
   697 };
       
   698 
       
   699 
       
   700 static PyTypeObject* make_type(char *type, PyTypeObject* base, char**fields, int num_fields)
       
   701 {
       
   702     PyObject *fnames, *result;
       
   703     int i;
       
   704     fnames = PyTuple_New(num_fields);
       
   705     if (!fnames) return NULL;
       
   706     for (i = 0; i < num_fields; i++) {
       
   707         PyObject *field = PyString_FromString(fields[i]);
       
   708         if (!field) {
       
   709             Py_DECREF(fnames);
       
   710             return NULL;
       
   711         }
       
   712         PyTuple_SET_ITEM(fnames, i, field);
       
   713     }
       
   714     result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){sOss}",
       
   715                     type, base, "_fields", fnames, "__module__", "_ast");
       
   716     Py_DECREF(fnames);
       
   717     return (PyTypeObject*)result;
       
   718 }
       
   719 
       
   720 static int add_attributes(PyTypeObject* type, char**attrs, int num_fields)
       
   721 {
       
   722     int i, result;
       
   723     PyObject *s, *l = PyTuple_New(num_fields);
       
   724     if (!l) return 0;
       
   725     for(i = 0; i < num_fields; i++) {
       
   726         s = PyString_FromString(attrs[i]);
       
   727         if (!s) {
       
   728             Py_DECREF(l);
       
   729             return 0;
       
   730         }
       
   731         PyTuple_SET_ITEM(l, i, s);
       
   732     }
       
   733     result = PyObject_SetAttrString((PyObject*)type, "_attributes", l) >= 0;
       
   734     Py_DECREF(l);
       
   735     return result;
       
   736 }
       
   737 
       
   738 /* Conversion AST -> Python */
       
   739 
       
   740 static PyObject* ast2obj_list(asdl_seq *seq, PyObject* (*func)(void*))
       
   741 {
       
   742     int i, n = asdl_seq_LEN(seq);
       
   743     PyObject *result = PyList_New(n);
       
   744     PyObject *value;
       
   745     if (!result)
       
   746         return NULL;
       
   747     for (i = 0; i < n; i++) {
       
   748         value = func(asdl_seq_GET(seq, i));
       
   749         if (!value) {
       
   750             Py_DECREF(result);
       
   751             return NULL;
       
   752         }
       
   753         PyList_SET_ITEM(result, i, value);
       
   754     }
       
   755     return result;
       
   756 }
       
   757 
       
   758 static PyObject* ast2obj_object(void *o)
       
   759 {
       
   760     if (!o)
       
   761         o = Py_None;
       
   762     Py_INCREF((PyObject*)o);
       
   763     return (PyObject*)o;
       
   764 }
       
   765 #define ast2obj_identifier ast2obj_object
       
   766 #define ast2obj_string ast2obj_object
       
   767 static PyObject* ast2obj_bool(bool b)
       
   768 {
       
   769     return PyBool_FromLong(b);
       
   770 }
       
   771 
       
   772 static PyObject* ast2obj_int(long b)
       
   773 {
       
   774     return PyInt_FromLong(b);
       
   775 }
       
   776 
       
   777 /* Conversion Python -> AST */
       
   778 
       
   779 static int obj2ast_object(PyObject* obj, PyObject** out, PyArena* arena)
       
   780 {
       
   781     if (obj == Py_None)
       
   782         obj = NULL;
       
   783     if (obj)
       
   784         PyArena_AddPyObject(arena, obj);
       
   785     Py_XINCREF(obj);
       
   786     *out = obj;
       
   787     return 0;
       
   788 }
       
   789 
       
   790 #define obj2ast_identifier obj2ast_object
       
   791 #define obj2ast_string obj2ast_object
       
   792 
       
   793 static int obj2ast_int(PyObject* obj, int* out, PyArena* arena)
       
   794 {
       
   795     int i;
       
   796     if (!PyInt_Check(obj) && !PyLong_Check(obj)) {
       
   797         PyObject *s = PyObject_Repr(obj);
       
   798         if (s == NULL) return 1;
       
   799         PyErr_Format(PyExc_ValueError, "invalid integer value: %.400s",
       
   800                      PyString_AS_STRING(s));
       
   801         Py_DECREF(s);
       
   802         return 1;
       
   803     }
       
   804 
       
   805     i = (int)PyLong_AsLong(obj);
       
   806     if (i == -1 && PyErr_Occurred())
       
   807         return 1;
       
   808     *out = i;
       
   809     return 0;
       
   810 }
       
   811 
       
   812 static int obj2ast_bool(PyObject* obj, bool* out, PyArena* arena)
       
   813 {
       
   814     if (!PyBool_Check(obj)) {
       
   815         PyObject *s = PyObject_Repr(obj);
       
   816         if (s == NULL) return 1;
       
   817         PyErr_Format(PyExc_ValueError, "invalid boolean value: %.400s",
       
   818                      PyString_AS_STRING(s));
       
   819         Py_DECREF(s);
       
   820         return 1;
       
   821     }
       
   822 
       
   823     *out = (obj == Py_True);
       
   824     return 0;
       
   825 }
       
   826 
       
   827 static int add_ast_fields(void)
       
   828 {
       
   829     PyObject *empty_tuple, *d;
       
   830     if (PyType_Ready(&AST_type) < 0)
       
   831         return -1;
       
   832     d = AST_type.tp_dict;
       
   833     empty_tuple = PyTuple_New(0);
       
   834     if (!empty_tuple ||
       
   835         PyDict_SetItemString(d, "_fields", empty_tuple) < 0 ||
       
   836         PyDict_SetItemString(d, "_attributes", empty_tuple) < 0) {
       
   837         Py_XDECREF(empty_tuple);
       
   838         return -1;
       
   839     }
       
   840     Py_DECREF(empty_tuple);
       
   841     return 0;
       
   842 }
       
   843 
       
   844 """, 0, reflow=False)
       
   845 
       
   846         self.emit("static int init_types(void)",0)
       
   847         self.emit("{", 0)
       
   848         self.emit("static int initialized;", 1)
       
   849         self.emit("if (initialized) return 1;", 1)
       
   850         self.emit("if (add_ast_fields() < 0) return 0;", 1)
       
   851         for dfn in mod.dfns:
       
   852             self.visit(dfn)
       
   853         self.emit("initialized = 1;", 1)
       
   854         self.emit("return 1;", 1);
       
   855         self.emit("}", 0)
       
   856 
       
   857     def visitProduct(self, prod, name):
       
   858         if prod.fields:
       
   859             fields = name.value+"_fields"
       
   860         else:
       
   861             fields = "NULL"
       
   862         self.emit('%s_type = make_type("%s", &AST_type, %s, %d);' %
       
   863                         (name, name, fields, len(prod.fields)), 1)
       
   864         self.emit("if (!%s_type) return 0;" % name, 1)
       
   865 
       
   866     def visitSum(self, sum, name):
       
   867         self.emit('%s_type = make_type("%s", &AST_type, NULL, 0);' %
       
   868                   (name, name), 1)
       
   869         self.emit("if (!%s_type) return 0;" % name, 1)
       
   870         if sum.attributes:
       
   871             self.emit("if (!add_attributes(%s_type, %s_attributes, %d)) return 0;" %
       
   872                             (name, name, len(sum.attributes)), 1)
       
   873         else:
       
   874             self.emit("if (!add_attributes(%s_type, NULL, 0)) return 0;" % name, 1)
       
   875         simple = is_simple(sum)
       
   876         for t in sum.types:
       
   877             self.visitConstructor(t, name, simple)
       
   878 
       
   879     def visitConstructor(self, cons, name, simple):
       
   880         if cons.fields:
       
   881             fields = cons.name.value+"_fields"
       
   882         else:
       
   883             fields = "NULL"
       
   884         self.emit('%s_type = make_type("%s", %s_type, %s, %d);' %
       
   885                             (cons.name, cons.name, name, fields, len(cons.fields)), 1)
       
   886         self.emit("if (!%s_type) return 0;" % cons.name, 1)
       
   887         if simple:
       
   888             self.emit("%s_singleton = PyType_GenericNew(%s_type, NULL, NULL);" %
       
   889                              (cons.name, cons.name), 1)
       
   890             self.emit("if (!%s_singleton) return 0;" % cons.name, 1)
       
   891 
       
   892 
       
   893 def parse_version(mod):
       
   894     return mod.version.value[12:-3]
       
   895 
       
   896 class ASTModuleVisitor(PickleVisitor):
       
   897 
       
   898     def visitModule(self, mod):
       
   899         self.emit("PyMODINIT_FUNC", 0)
       
   900         self.emit("init_ast(void)", 0)
       
   901         self.emit("{", 0)
       
   902         self.emit("PyObject *m, *d;", 1)
       
   903         self.emit("if (!init_types()) return;", 1)
       
   904         self.emit('m = Py_InitModule3("_ast", NULL, NULL);', 1)
       
   905         self.emit("if (!m) return;", 1)
       
   906         self.emit("d = PyModule_GetDict(m);", 1)
       
   907         self.emit('if (PyDict_SetItemString(d, "AST", (PyObject*)&AST_type) < 0) return;', 1)
       
   908         self.emit('if (PyModule_AddIntConstant(m, "PyCF_ONLY_AST", PyCF_ONLY_AST) < 0)', 1)
       
   909         self.emit("return;", 2)
       
   910         # Value of version: "$Revision: 67146 $"
       
   911         self.emit('if (PyModule_AddStringConstant(m, "__version__", "%s") < 0)'
       
   912                 % parse_version(mod), 1)
       
   913         self.emit("return;", 2)
       
   914         for dfn in mod.dfns:
       
   915             self.visit(dfn)
       
   916         self.emit("}", 0)
       
   917 
       
   918     def visitProduct(self, prod, name):
       
   919         self.addObj(name)
       
   920 
       
   921     def visitSum(self, sum, name):
       
   922         self.addObj(name)
       
   923         for t in sum.types:
       
   924             self.visitConstructor(t, name)
       
   925 
       
   926     def visitConstructor(self, cons, name):
       
   927         self.addObj(cons.name)
       
   928 
       
   929     def addObj(self, name):
       
   930         self.emit('if (PyDict_SetItemString(d, "%s", (PyObject*)%s_type) < 0) return;' % (name, name), 1)
       
   931 
       
   932 
       
   933 _SPECIALIZED_SEQUENCES = ('stmt', 'expr')
       
   934 
       
   935 def find_sequence(fields, doing_specialization):
       
   936     """Return True if any field uses a sequence."""
       
   937     for f in fields:
       
   938         if f.seq:
       
   939             if not doing_specialization:
       
   940                 return True
       
   941             if str(f.type) not in _SPECIALIZED_SEQUENCES:
       
   942                 return True
       
   943     return False
       
   944 
       
   945 def has_sequence(types, doing_specialization):
       
   946     for t in types:
       
   947         if find_sequence(t.fields, doing_specialization):
       
   948             return True
       
   949     return False
       
   950 
       
   951 
       
   952 class StaticVisitor(PickleVisitor):
       
   953     CODE = '''Very simple, always emit this static code.  Overide CODE'''
       
   954 
       
   955     def visit(self, object):
       
   956         self.emit(self.CODE, 0, reflow=False)
       
   957 
       
   958 
       
   959 class ObjVisitor(PickleVisitor):
       
   960 
       
   961     def func_begin(self, name):
       
   962         ctype = get_c_type(name)
       
   963         self.emit("PyObject*", 0)
       
   964         self.emit("ast2obj_%s(void* _o)" % (name), 0)
       
   965         self.emit("{", 0)
       
   966         self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
       
   967         self.emit("PyObject *result = NULL, *value = NULL;", 1)
       
   968         self.emit('if (!o) {', 1)
       
   969         self.emit("Py_INCREF(Py_None);", 2)
       
   970         self.emit('return Py_None;', 2)
       
   971         self.emit("}", 1)
       
   972         self.emit('', 0)
       
   973 
       
   974     def func_end(self):
       
   975         self.emit("return result;", 1)
       
   976         self.emit("failed:", 0)
       
   977         self.emit("Py_XDECREF(value);", 1)
       
   978         self.emit("Py_XDECREF(result);", 1)
       
   979         self.emit("return NULL;", 1)
       
   980         self.emit("}", 0)
       
   981         self.emit("", 0)
       
   982 
       
   983     def visitSum(self, sum, name):
       
   984         if is_simple(sum):
       
   985             self.simpleSum(sum, name)
       
   986             return
       
   987         self.func_begin(name)
       
   988         self.emit("switch (o->kind) {", 1)
       
   989         for i in range(len(sum.types)):
       
   990             t = sum.types[i]
       
   991             self.visitConstructor(t, i + 1, name)
       
   992         self.emit("}", 1)
       
   993         for a in sum.attributes:
       
   994             self.emit("value = ast2obj_%s(o->%s);" % (a.type, a.name), 1)
       
   995             self.emit("if (!value) goto failed;", 1)
       
   996             self.emit('if (PyObject_SetAttrString(result, "%s", value) < 0)' % a.name, 1)
       
   997             self.emit('goto failed;', 2)
       
   998             self.emit('Py_DECREF(value);', 1)
       
   999         self.func_end()
       
  1000 
       
  1001     def simpleSum(self, sum, name):
       
  1002         self.emit("PyObject* ast2obj_%s(%s_ty o)" % (name, name), 0)
       
  1003         self.emit("{", 0)
       
  1004         self.emit("switch(o) {", 1)
       
  1005         for t in sum.types:
       
  1006             self.emit("case %s:" % t.name, 2)
       
  1007             self.emit("Py_INCREF(%s_singleton);" % t.name, 3)
       
  1008             self.emit("return %s_singleton;" % t.name, 3)
       
  1009         self.emit("default:" % name, 2)
       
  1010         self.emit('/* should never happen, but just in case ... */', 3)
       
  1011         code = "PyErr_Format(PyExc_SystemError, \"unknown %s found\");" % name
       
  1012         self.emit(code, 3, reflow=False)
       
  1013         self.emit("return NULL;", 3)
       
  1014         self.emit("}", 1)
       
  1015         self.emit("}", 0)
       
  1016 
       
  1017     def visitProduct(self, prod, name):
       
  1018         self.func_begin(name)
       
  1019         self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % name, 1);
       
  1020         self.emit("if (!result) return NULL;", 1)
       
  1021         for field in prod.fields:
       
  1022             self.visitField(field, name, 1, True)
       
  1023         self.func_end()
       
  1024 
       
  1025     def visitConstructor(self, cons, enum, name):
       
  1026         self.emit("case %s_kind:" % cons.name, 1)
       
  1027         self.emit("result = PyType_GenericNew(%s_type, NULL, NULL);" % cons.name, 2);
       
  1028         self.emit("if (!result) goto failed;", 2)
       
  1029         for f in cons.fields:
       
  1030             self.visitField(f, cons.name, 2, False)
       
  1031         self.emit("break;", 2)
       
  1032 
       
  1033     def visitField(self, field, name, depth, product):
       
  1034         def emit(s, d):
       
  1035             self.emit(s, depth + d)
       
  1036         if product:
       
  1037             value = "o->%s" % field.name
       
  1038         else:
       
  1039             value = "o->v.%s.%s" % (name, field.name)
       
  1040         self.set(field, value, depth)
       
  1041         emit("if (!value) goto failed;", 0)
       
  1042         emit('if (PyObject_SetAttrString(result, "%s", value) == -1)' % field.name, 0)
       
  1043         emit("goto failed;", 1)
       
  1044         emit("Py_DECREF(value);", 0)
       
  1045 
       
  1046     def emitSeq(self, field, value, depth, emit):
       
  1047         emit("seq = %s;" % value, 0)
       
  1048         emit("n = asdl_seq_LEN(seq);", 0)
       
  1049         emit("value = PyList_New(n);", 0)
       
  1050         emit("if (!value) goto failed;", 0)
       
  1051         emit("for (i = 0; i < n; i++) {", 0)
       
  1052         self.set("value", field, "asdl_seq_GET(seq, i)", depth + 1)
       
  1053         emit("if (!value1) goto failed;", 1)
       
  1054         emit("PyList_SET_ITEM(value, i, value1);", 1)
       
  1055         emit("value1 = NULL;", 1)
       
  1056         emit("}", 0)
       
  1057 
       
  1058     def set(self, field, value, depth):
       
  1059         if field.seq:
       
  1060             # XXX should really check for is_simple, but that requires a symbol table
       
  1061             if field.type.value == "cmpop":
       
  1062                 # While the sequence elements are stored as void*,
       
  1063                 # ast2obj_cmpop expects an enum
       
  1064                 self.emit("{", depth)
       
  1065                 self.emit("int i, n = asdl_seq_LEN(%s);" % value, depth+1)
       
  1066                 self.emit("value = PyList_New(n);", depth+1)
       
  1067                 self.emit("if (!value) goto failed;", depth+1)
       
  1068                 self.emit("for(i = 0; i < n; i++)", depth+1)
       
  1069                 # This cannot fail, so no need for error handling
       
  1070                 self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop((cmpop_ty)asdl_seq_GET(%s, i)));" % value,
       
  1071                           depth+2, reflow=False)
       
  1072                 self.emit("}", depth)
       
  1073             else:
       
  1074                 self.emit("value = ast2obj_list(%s, ast2obj_%s);" % (value, field.type), depth)
       
  1075         else:
       
  1076             ctype = get_c_type(field.type)
       
  1077             self.emit("value = ast2obj_%s(%s);" % (field.type, value), depth, reflow=False)
       
  1078 
       
  1079 
       
  1080 class PartingShots(StaticVisitor):
       
  1081 
       
  1082     CODE = """
       
  1083 PyObject* PyAST_mod2obj(mod_ty t)
       
  1084 {
       
  1085     init_types();
       
  1086     return ast2obj_mod(t);
       
  1087 }
       
  1088 
       
  1089 /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
       
  1090 mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
       
  1091 {
       
  1092     mod_ty res;
       
  1093     PyObject *req_type[] = {(PyObject*)Module_type, (PyObject*)Expression_type,
       
  1094                             (PyObject*)Interactive_type};
       
  1095     char *req_name[] = {"Module", "Expression", "Interactive"};
       
  1096     assert(0 <= mode && mode <= 2);
       
  1097 
       
  1098     init_types();
       
  1099 
       
  1100     if (!PyObject_IsInstance(ast, req_type[mode])) {
       
  1101         PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
       
  1102                      req_name[mode], Py_TYPE(ast)->tp_name);
       
  1103         return NULL;
       
  1104     }
       
  1105     if (obj2ast_mod(ast, &res, arena) != 0)
       
  1106         return NULL;
       
  1107     else
       
  1108         return res;
       
  1109 }
       
  1110 
       
  1111 int PyAST_Check(PyObject* obj)
       
  1112 {
       
  1113     init_types();
       
  1114     return PyObject_IsInstance(obj, (PyObject*)&AST_type);
       
  1115 }
       
  1116 """
       
  1117 
       
  1118 class ChainOfVisitors:
       
  1119     def __init__(self, *visitors):
       
  1120         self.visitors = visitors
       
  1121 
       
  1122     def visit(self, object):
       
  1123         for v in self.visitors:
       
  1124             v.visit(object)
       
  1125             v.emit("", 0)
       
  1126 
       
  1127 common_msg = "/* File automatically generated by %s. */\n\n"
       
  1128 
       
  1129 c_file_msg = """
       
  1130 /*
       
  1131    __version__ %s.
       
  1132 
       
  1133    This module must be committed separately after each AST grammar change;
       
  1134    The __version__ number is set to the revision number of the commit
       
  1135    containing the grammar change.
       
  1136 */
       
  1137 
       
  1138 """
       
  1139 
       
  1140 def main(srcfile):
       
  1141     argv0 = sys.argv[0]
       
  1142     components = argv0.split(os.sep)
       
  1143     argv0 = os.sep.join(components[-2:])
       
  1144     auto_gen_msg = common_msg % argv0
       
  1145     mod = asdl.parse(srcfile)
       
  1146     if not asdl.check(mod):
       
  1147         sys.exit(1)
       
  1148     if INC_DIR:
       
  1149         p = "%s/%s-ast.h" % (INC_DIR, mod.name)
       
  1150         f = open(p, "wb")
       
  1151         f.write(auto_gen_msg)
       
  1152         f.write('#include "asdl.h"\n\n')
       
  1153         c = ChainOfVisitors(TypeDefVisitor(f),
       
  1154                             StructVisitor(f),
       
  1155                             PrototypeVisitor(f),
       
  1156                             )
       
  1157         c.visit(mod)
       
  1158         f.write("PyObject* PyAST_mod2obj(mod_ty t);\n")
       
  1159         f.write("mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);\n")
       
  1160         f.write("int PyAST_Check(PyObject* obj);\n")
       
  1161         f.close()
       
  1162 
       
  1163     if SRC_DIR:
       
  1164         p = os.path.join(SRC_DIR, str(mod.name) + "-ast.c")
       
  1165         f = open(p, "wb")
       
  1166         f.write(auto_gen_msg)
       
  1167         f.write(c_file_msg % parse_version(mod))
       
  1168         f.write('#include "Python.h"\n')
       
  1169         f.write('#include "%s-ast.h"\n' % mod.name)
       
  1170         f.write('\n')
       
  1171         f.write("static PyTypeObject AST_type;\n")
       
  1172         v = ChainOfVisitors(
       
  1173             PyTypesDeclareVisitor(f),
       
  1174             PyTypesVisitor(f),
       
  1175             Obj2ModPrototypeVisitor(f),
       
  1176             FunctionVisitor(f),
       
  1177             ObjVisitor(f),
       
  1178             Obj2ModVisitor(f),
       
  1179             ASTModuleVisitor(f),
       
  1180             PartingShots(f),
       
  1181             )
       
  1182         v.visit(mod)
       
  1183         f.close()
       
  1184 
       
  1185 if __name__ == "__main__":
       
  1186     import sys
       
  1187     import getopt
       
  1188 
       
  1189     INC_DIR = ''
       
  1190     SRC_DIR = ''
       
  1191     opts, args = getopt.getopt(sys.argv[1:], "h:c:")
       
  1192     if len(opts) != 1:
       
  1193         print "Must specify exactly one output file"
       
  1194         sys.exit(1)
       
  1195     for o, v in opts:
       
  1196         if o == '-h':
       
  1197             INC_DIR = v
       
  1198         if o == '-c':
       
  1199             SRC_DIR = v
       
  1200     if len(args) != 1:
       
  1201         print "Must specify single input file"
       
  1202         sys.exit(1)
       
  1203     main(args[0])