#! /usr/bin/env python
"""Generate C code from an ASDL description.

Running this file as a script parses one ASDL file and writes two files
named after the ASDL module: ``<module>-ast.h`` (struct declarations,
type-check macros and constructor prototypes) and ``<module>-ast.c``
(constructors, deallocators, type objects and validation functions).
"""

# TO DO
# handle fields that have a type but no name

import os
import sys
import traceback
import operator

import asdl

TABSIZE = 8
MAX_COL = 80

# Output directories for the generated header and C file.  The script
# driver at the bottom of the file overrides them from the -h / -c options;
# the defaults here keep main() usable when this module is imported.
INC_DIR = ''
SRC_DIR = ''


def get_c_type(name):
    """Return a string for the C name of the type.

    This function special cases int; everything else is PyObject*.
    `name` may be a plain string or an asdl.Id, whose .value is used.
    """
    if isinstance(name, asdl.Id):
        name = name.value
    # Compare for equality: ("int") is just the string "int", so an "in"
    # test against it would also accept substrings such as "i" or "nt".
    if name == "int":
        return name
    return "PyObject*"


def reflow_lines(s, depth):
    """Reflow the line s indented depth tabs.

    Return a sequence of lines where no line extends beyond MAX_COL
    when properly indented.  The first line is properly indented based
    exclusively on depth * TABSIZE.  All following lines -- these are
    the reflowed lines generated by this function -- start at the same
    column as the first character beyond the opening { in the first
    line (or beyond the opening paren when the first line has no brace).

    Lines are only broken at spaces; an AssertionError is raised when no
    space is available to break at.
    """
    size = MAX_COL - depth * TABSIZE
    if len(s) < size:
        return [s]

    lines = []
    cur = s
    padding = ""
    while len(cur) > size:
        i = cur.rfind(' ', 0, size)
        # XXX this should be fixed for real
        if i == -1 and 'GeneratorExp' in cur:
            i = size + 3
        assert i != -1, "Impossible line %d to reflow: %s" % (size, repr(s))
        lines.append(padding + cur[:i])
        if len(lines) == 1:
            # find new size based on brace
            j = cur.find('{', 0, i)
            if j >= 0:
                j += 2  # account for the brace and the space after it
                size -= j
                padding = " " * j
            else:
                j = cur.find('(', 0, i)
                if j >= 0:
                    j += 1  # account for the paren (no space after it)
                    size -= j
                    padding = " " * j
        cur = cur[i+1:]
    # The loop never breaks, so the remainder is always appended.
    lines.append(padding + cur)
    return lines


def is_simple(sum):
    """Return True if a sum is a simple.

    A sum is simple if its types have no fields, e.g.
    unaryop = Invert | Not | UAdd | USub
    """
    for t in sum.types:
        if t.fields:
            return False
    return True


class EmitVisitor(asdl.VisitorBase):
    """Visitor that emits lines to a file-like object."""

    def __init__(self, file):
        """Remember `file`; every emitted line is written to it."""
        self.file = file
        super(EmitVisitor, self).__init__()

    def emit(self, s, depth, reflow=1):
        """Write `s` to the output, indented by `depth` tab stops.

        When `reflow` is true the text is first passed through
        reflow_lines() and may be written as several lines; otherwise it
        is written as one line regardless of its length.
        """
        # XXX reflow long lines?
        if reflow:
            lines = reflow_lines(s, depth)
        else:
            lines = [s]
        indent = " " * TABSIZE * depth
        for line in lines:
            self.file.write(indent + line + "\n")


class TraversalVisitor(EmitVisitor):
    """Base visitor that walks module -> type -> sum constructors."""

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        """Visit the value of a type definition, passing the type name."""
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        """Visit every constructor of a sum with the sum's attributes."""
        for t in sum.types:
            self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is 3-tuple of a C type, variable name, and flag
        that is true if type can be NULL.  A field without a name gets
        the name "name<index>".
        """
        args = []
        for i, f in enumerate(fields):
            if f.name is None:
                name = "name%d" % i
            else:
                name = f.name
            ctype = get_c_type(f.type)
            args.append((ctype, name, f.opt or f.seq))
        return args


class HeaderVisitor(EmitVisitor):
    """Visitor to generate typedefs for AST."""

    def emit_check(self, t, depth):
        """Emit the type object declaration and <t>_Check macro for `t`."""
        self.emit("PyAPI_DATA(PyTypeObject) Py_%s_Type;" % t, depth)
        self.emit("#define %s_Check(op) PyObject_TypeCheck(op, &Py_%s_Type)"
                  % (t, t), depth, reflow=False)
        self.emit("", depth)

    def emit_field_access(self, name, fields):
        """Emit one accessor macro <name>_<field>(o) per field."""
        for f in fields:
            # A #define must stay on one line, so never reflow it.
            self.emit("#define %s_%s(o) (((struct _%s*)o)->%s)"
                      % (name, f.name, name, f.name), 0, reflow=False)

    def visitModule(self, mod):
        """Visit every definition of the module."""
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        """Visit the value of a type definition at indentation `depth`."""
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        """Emit the declarations for a sum type and its constructors."""
        self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit the base struct of a sum, then each constructor struct.

        The base struct holds PyObject_HEAD, a _kind enum with one value
        per constructor (after a leading <name>_Dummy_kind), and the sum's
        attributes, which must be ASDL builtin types.
        """
        self.emit_check(name, depth)
        self.emit("struct _%s{" % name, depth)
        self.emit("PyObject_HEAD", depth + 1)
        names = ["%s_Dummy_kind" % name]
        names += [t.name.value + "_kind" for t in sum.types]
        self.emit("enum {%s} _kind;" % ", ".join(names), depth + 1)
        for field in sum.attributes:
            field_type = str(field.type)
            assert field_type in asdl.builtin_types, field_type
            self.emit("%s %s;" % (field_type, field.name), depth + 1)
        self.emit("};", depth)
        self.emit("#define %s_kind(o) (((struct _%s*)o)->_kind)"
                  % (name, name), depth, reflow=False)
        self.emit("", depth)
        for t in sum.types:
            self.visitConstructor(name, t, sum.attributes, depth)

    def visitConstructor(self, name, cons, attrs, depth):
        """Emit struct, constructor prototype and macros for `cons`.

        `name` is the enclosing sum; its struct is embedded as _base.  The
        prototype takes the constructor's fields followed by the sum's
        attributes, or void when there are none.
        """
        self.emit_check(cons.name, depth)
        self.emit("struct _%s{" % cons.name, depth)
        self.emit("struct _%s _base;" % name, depth + 1)
        field_types = []
        for f in cons.fields:
            field_types.append(get_c_type(f.type))
            self.visit(f, depth + 1)
        for f in attrs:
            field_types.append(get_c_type(f.type))
        self.emit("};", depth)
        args = ", ".join(field_types) or "void"
        self.emit("PyObject *Py_%s_New(%s);" % (cons.name, args), depth)
        # for convenience
        self.emit("#define %s Py_%s_New" % (cons.name, cons.name), depth,
                  reflow=False)
        self.emit_field_access(cons.name, cons.fields)
        self.emit("", depth)

    def visitField(self, field, depth):
        """Emit the struct member for `field`, with its ASDL type noted."""
        ctype = get_c_type(field.type)
        self.emit("%s %s; /* %s */" % (ctype, field.name, field.type), depth)

    def visitProduct(self, product, name, depth):
        """Emit struct, constructor prototype and macros for a product."""
        self.emit_check(str(name), depth)
        self.emit("struct _%s {" % name, depth)
        self.emit("PyObject_HEAD", depth + 1)
        field_types = []
        for f in product.fields:
            field_types.append(get_c_type(f.type))
            self.visit(f, depth + 1)
        self.emit("};", depth)
        # Same convention as visitConstructor: an empty parameter list is
        # spelled (void) so the prototype really declares "no arguments".
        args = ", ".join(field_types) or "void"
        self.emit("PyObject *Py_%s_New(%s);" % (name, args), depth)
        self.emit("#define %s Py_%s_New" % (name, name), depth, reflow=False)
        self.emit_field_access(name, product.fields)
        self.emit("", depth)


class ForwardVisitor(TraversalVisitor):
    """Visitor that emits forward declarations of the validators."""

    def emit_validate(self, name):
        """Emit the forward declaration of <name>_validate."""
        self.emit("static int %s_validate(PyObject*);" % name, 0)

    def visitSum(self, sum, name):
        self.emit_validate(name)
        for t in sum.types:
            self.visit(t, name, sum.attributes)

    def visitProduct(self, prod, name):
        self.emit_validate(name)

    def visitConstructor(self, cons, name, attrs):
        self.emit_validate(cons.name)


class FunctionVisitor(TraversalVisitor):
    """Visitor to generate constructor functions for AST."""

    def check(self, t):
        """Return the name of the C type-check macro for ASDL type `t`.

        `t` must have a .value attribute holding the type name.
        """
        t = t.value
        if t in ("identifier", "string"):
            return "string_Check"
        elif t == "bool":
            return "PyBool_Check"
        else:
            return t + "_Check"

    def emit_ctor(self, name, args, attrs):
        """Emit the C function Py_<name>_New.

        `args` are the ASDL fields of the node; `attrs` are
        (ctype, name, nullable) tuples as returned by get_args() and are
        stored in the embedded _base struct.
        """
        def emit(s, depth=0, reflow=1):
            self.emit(s, depth, reflow)

        argstr = ["%s %s" % (get_c_type(f.type), f.name) for f in args]
        argstr += ["%s %s" % (argtype, argname)
                   for argtype, argname, opt in attrs]
        argstr = ", ".join(argstr)
        emit("PyObject*")
        emit("Py_%s_New(%s)" % (name, argstr))
        emit("{")
        emit("struct _%s *result = PyObject_New(struct _%s, &Py_%s_Type);"
             % (name, name, name), 1, 0)
        emit("if (result == NULL)", 1)
        emit("return NULL;", 2)
        for f in args:
            argtype = get_c_type(f.type)
            if argtype == "PyObject*":
                # NOTE(review): when the argument is NULL the generated C
                # creates/increfs a replacement object here and then takes
                # another reference below -- confirm the extra reference
                # is intended.
                if f.opt:
                    emit("if (%s == NULL) {" % f.name, 1)
                    emit("Py_INCREF(Py_None);", 2)
                    emit("%s = Py_None;" % f.name, 2)
                    emit("}", 1)
                elif f.seq:
                    emit("if (%s == NULL)" % f.name, 1)
                    emit("%s = PyList_New(0);" % f.name, 2)
                emit("Py_INCREF(%s);" % f.name, 1)
            emit("result->%s = %s;" % (f.name, f.name), 1)
        # Only names starting with an upper-case letter get a _kind set.
        if str(name)[0].isupper():  # HACK !
            emit("result->_base._kind = %s_kind;" % name, 1)
        for argtype, argname, opt in attrs:
            if argtype == "PyObject*":
                emit("Py_INCREF(%s);" % argname, 1)
            emit("result->_base.%s = %s;" % (argname, argname), 1)
        emit("return (PyObject*)result;", 1)
        emit("}")
        emit("")

    def emit_dealloc(self, name, fields, attrs):
        """Emit <name>_dealloc, releasing every PyObject* member.

        Both `fields` and `attrs` are (ctype, name, nullable) tuples as
        returned by get_args().
        """
        def emit(s, depth=0, reflow=1):
            self.emit(s, depth, reflow)

        emit("static void")
        emit("%s_dealloc(PyObject* _self)" % name)
        emit("{")
        emit("struct _%s *self = (struct _%s*)_self;" % (name, name), 1)
        for argtype, argname, opt in fields:
            if argtype == "PyObject*":
                emit("Py_DECREF(self->%s);" % argname, 1)
        for argtype, argname, opt in attrs:
            if argtype == "PyObject*":
                emit("Py_DECREF(self->_base.%s);" % argname, 1)
        emit("PyObject_Del(self);", 1)
        emit("}")
        emit("")

    def emit_seq_check(self, f):
        """Emit validation code for the sequence field `f`.

        The generated C checks that the member is a list and type-checks
        each item; items of any type but identifier are also validated.
        """
        depth = 1

        # emit() reads `depth` when it is called, so rebinding `depth`
        # below changes the indentation of the following lines.
        def emit(s):
            self.emit(s, depth)

        emit("if (!PyList_Check(obj->%s)) {" % f.name)
        emit(' failed_check("%s", "list", obj->%s);' % (f.name, f.name))
        emit(' return -1;')
        emit("}")
        emit("for(i = 0; i < PyList_Size(obj->%s); i++) {" % f.name)
        depth = 2
        check = self.check(f.type)
        emit("if (!%s(PyList_GET_ITEM(obj->%s, i))) {" % (check, f.name))
        emit(' failed_check("%s", "%s", PyList_GET_ITEM(obj->%s, i));'
             % (f.name, f.type, f.name))
        emit(' return -1;')
        emit("}")
        if f.type.value not in ('identifier',):
            emit("if (%s_validate(PyList_GET_ITEM(obj->%s, i)) < 0)"
                 % (f.type, f.name))
            emit(" return -1;")
        depth = 1
        emit("}")

    def emit_opt_check(self, f):
        """Emit validation code for the optional field `f`.

        The generated C accepts Py_None; otherwise the member is
        type-checked and, unless it is an identifier, validated.
        """
        depth = 1

        def emit(s):
            self.emit(s, depth)

        emit("if (obj->%s == Py_None) /* empty */;" % f.name)
        check = self.check(f.type)
        emit("else if (!%s(obj->%s)) {" % (check, f.name))
        emit(' failed_check("%s", "%s", obj->%s);' % (f.name, f.type, f.name))
        emit(' return -1;')
        emit("}")
        if f.type.value not in ('identifier',):
            emit("else if (%s_validate(obj->%s) < 0)" % (f.type, f.name))
            emit(" return -1;")

    def emit_field_check(self, f, inbase):
        """Emit a type check for a required field or a base attribute.

        With `inbase` true, `f` is a (ctype, name, nullable) tuple and the
        member is reached through _base; otherwise `f` is an ASDL field.
        Nothing is emitted when the type compares equal to "int".
        """
        depth = 1

        def emit(s):
            self.emit(s, depth)

        base = ""
        if inbase:
            base = "_base."
            type = f[0]
            name = f[1]
        else:
            type = f.type
            name = f.name
        if type == "int":
            return
        # NOTE(review): check() reads type.value, so a non-int base
        # attribute (a plain C type string here) would fail at this
        # point -- confirm attributes are always int.
        check = self.check(type)
        emit("if (!%s(obj->%s%s)) {" % (check, base, name))
        emit(' failed_check("%s", "%s", obj->%s%s);'
             % (name, type, base, name))
        emit(' return -1;')
        emit("}")

    def emit_validate(self, name, fields, attrs):
        """Emit the static <name>_validate function for a node type.

        `fields` are ASDL fields; `attrs` are (ctype, name, nullable)
        tuples checked through the _base struct.
        """
        depth = 0

        def emit(s):
            self.emit(s, depth)

        # The loop counter is only declared when a sequence is checked.
        has_seq = reduce(operator.or_, [f.seq for f in fields], False)
        emit("static int")
        emit("%s_validate(PyObject *_obj)" % name)
        emit("{")
        depth = 1
        if fields:
            emit("struct _%s *obj = (struct _%s*)_obj;" % (name, name))
        if has_seq:
            emit("int i;")
        for f in fields:
            if f.seq:
                self.emit_seq_check(f)
            elif f.opt:
                self.emit_opt_check(f)
            else:
                self.emit_field_check(f, False)
        for f in attrs:
            self.emit_field_check(f, True)
        emit("return 0;")
        depth = 0
        emit("}")
        emit("")

    def emit_sumvalidate(self, sum, name):
        """Emit <name>_validate for a sum: dispatch on _kind."""
        depth = 0

        def emit(s):
            self.emit(s, depth)

        emit("int")
        emit("%s_validate(PyObject* _obj)" % name)
        emit("{")
        depth = 1
        emit("struct _%s *obj = (struct _%s*)_obj;" % (name, name))
        # caller should have verified that this is the correct type
        emit("assert(%s_Check(_obj));" % (name))
        emit("switch(obj->_kind) {")
        depth = 2
        for t in sum.types:
            emit("case %s_kind:" % t.name)
            emit(" return %s_validate(_obj);" % t.name)
        emit("default:")
        emit(" break;")
        depth = 1
        emit("}")
        emit('PyErr_SetString(PyExc_TypeError, "invalid _kind in %s");' % name)
        emit('return -1;')
        depth = 0
        emit("}")

    def emit_type(self, name):
        """Emit the PyTypeObject initializer Py_<name>_Type.

        Only tp_name, tp_basicsize, tp_dealloc and tp_flags are filled
        in; every other slot is emitted as 0.
        """
        depth = 0

        def emit(s):
            self.emit(s, depth)

        def null(thing):
            emit("0,\t\t/* tp_%s */" % thing)

        emit("PyTypeObject Py_%s_Type = {" % name)
        depth = 1
        emit("PyObject_HEAD_INIT(NULL)")
        emit("0,\t\t/*ob_size*/")
        emit('"%s",\t\t/*tp_name*/' % name)
        emit("sizeof(struct _%s),\t/*tp_basicsize*/" % name)
        null("itemsize")
        emit("%s_dealloc,\t\t/*tp_dealloc*/" % name)
        null("print")
        null("getattr")
        null("setattr")
        for m in ("compare", "repr", "as_number", "as_sequence",
                  "as_mapping", "hash", "call", "str", "getattro",
                  "setattro", "as_buffer"):
            null(m)
        emit("Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE,\t\t/*tp_flags*/")
        for m in ("doc", "traverse", "clear", "richcompare",
                  "weaklistoffset", "iter", "iternext", "methods",
                  "members", "getset", "base", "dict", "descr_get",
                  "descr_set", "dictoffset", "init", "alloc", "new",
                  "free", "is_gc"):
            null(m)
        depth = 0
        emit("};")
        emit("")

    def visitSum(self, sum, name):
        """Emit type object and validator for a sum, then its members."""
        # The sum's own type has no deallocator; emit_type() refers to
        # <name>_dealloc, so define it as 0.  A #define is never reflowed.
        self.emit("#define %s_dealloc 0" % name, 0, reflow=False)
        self.emit_type(name)
        self.emit_sumvalidate(sum, name)
        for t in sum.types:
            self.visit(t, name, sum.attributes)

    def visitProduct(self, prod, name):
        """Emit constructor, dealloc, type object, validator: product."""
        args = self.get_args(prod.fields)
        self.emit_ctor(str(name), prod.fields, [])
        self.emit_dealloc(str(name), args, [])
        self.emit_type(str(name))
        self.emit_validate(str(name), prod.fields, [])

    def visitConstructor(self, cons, name, attrs):
        """Emit constructor, dealloc, type object, validator: `cons`."""
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        self.emit_ctor(cons.name, cons.fields, attrs)
        self.emit_dealloc(cons.name, args, attrs)
        self.emit_type(cons.name)
        self.emit_validate(cons.name, cons.fields, attrs)


class PyAST_Validate(TraversalVisitor):
    """Visitor that emits the top-level PyAST_Validate C function."""

    def visitModule(self, mod):
        self.emit("int PyAST_Validate(PyObject *obj)", 0)
        self.emit("{", 0)
        for dfn in mod.dfns:
            self.visit(dfn)
        # The %s below is part of the generated C format string.
        self.emit('PyErr_Format(PyExc_TypeError, "Not an AST node: %s", '
                  'obj->ob_type->tp_name);', 1)
        self.emit("return -1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitType(self, t):
        self.emit("if (%s_Check(obj))" % t.name, 1)
        self.emit("return %s_validate(obj);" % t.name, 2)


class InitVisitor(TraversalVisitor):
    """Visitor that emits init_ast(), which readies every type object."""

    def visitModule(self, mod):
        self.emit("void init_ast(void)", 0)
        self.emit("{", 0)
        for dfn in mod.dfns:
            self.visit(dfn)
        self.emit("}", 0)

    def visitSum(self, sum, name):
        self.emit_init(name)
        for t in sum.types:
            self.visit(t, name, sum.attributes)

    def visitProduct(self, prod, name):
        self.emit_init(name)

    def visitConstructor(self, cons, name, attrs):
        self.emit_init(cons.name, name)

    def emit_init(self, name, base=None):
        """Emit the PyType_Ready call for `name`.

        When `base` is given, tp_base is set to that type first.
        """
        if base:
            self.emit("Py_%s_Type.tp_base = &Py_%s_Type;" % (name, base), 1)
        self.emit("if (PyType_Ready(&Py_%s_Type) < 0)" % name, 1)
        self.emit("return;", 2)


class ChainOfVisitors:
    """Run several visitors in order over the same object."""

    def __init__(self, *visitors):
        self.visitors = visitors

    def visit(self, object):
        """Let each visitor visit `object`, emitting a blank line after."""
        for v in self.visitors:
            v.visit(object)
            v.emit("", 0)


# C code copied verbatim into the generated source file.
static_code = """
static void
failed_check(const char* field, const char* expected, PyObject *real)
{
    PyErr_Format(PyExc_TypeError, "invalid %s: expected %s, found %s",
                 field, expected, real->ob_type->tp_name);
}
/* Convenience macro to simplify asdl_c.py */
#define object_Check(x) 1
#define string_Check(x) (PyString_Check(x)||PyUnicode_Check(x))
"""


def main(srcfile):
    """Generate <module>-ast.h and <module>-ast.c from `srcfile`.

    The header is written to INC_DIR and the C file to SRC_DIR; an empty
    directory means the current directory.  Exits with status 1 when
    asdl.check() rejects the parsed module.
    """
    auto_gen_msg = '/* File automatically generated by %s */\n' % sys.argv[0]
    mod = asdl.parse(srcfile)
    if not asdl.check(mod):
        sys.exit(1)

    # os.path.join ignores an empty directory component, so no special
    # case is needed when INC_DIR / SRC_DIR are unset.
    header_path = os.path.join(INC_DIR, "%s-ast.h" % mod.name)
    f = open(header_path, "wb")
    try:
        f.write(auto_gen_msg + "\n")
        f.write("/* For convenience, this header provides several\n")
        f.write(" macro, type and constant names which are not"
                " Py_-prefixed.\n")
        f.write(" Therefore, the file should not be included in Python.h;\n")
        f.write(" all symbols relevant to linkage are Py_-prefixed. */\n")
        f.write("\n\n")
        HeaderVisitor(f).visit(mod)
    finally:
        # Close the file even when code generation raises.
        f.close()

    source_path = os.path.join(SRC_DIR, "%s-ast.c" % mod.name)
    f = open(source_path, "wb")
    try:
        f.write(auto_gen_msg + "\n")
        f.write('#include "Python.h"\n')
        f.write('#include "%s-ast.h"\n' % mod.name)
        f.write("\n")
        f.write(static_code + "\n")
        v = ChainOfVisitors(
            ForwardVisitor(f),
            FunctionVisitor(f),
            PyAST_Validate(f),
            InitVisitor(f),
        )
        v.visit(mod)
    finally:
        f.close()


if __name__ == "__main__":
    import getopt

    # -h DIR: directory for the generated header
    # -c DIR: directory for the generated C file
    opts, args = getopt.getopt(sys.argv[1:], "h:c:")
    for o, v in opts:
        if o == '-h':
            INC_DIR = v
        if o == '-c':
            SRC_DIR = v
    if len(args) != 1:
        # Stop here: continuing would either fail on args[0] or silently
        # ignore the extra input files.
        sys.stderr.write("Must specify single input file\n")
        sys.exit(1)
    main(args[0])