from compiler.ast import *
import re


# Parse tree objects
class Declaration:
    """One parsed ``def`` declaration.

    Attributes:
        returns: type object for the return value, or None when the
            declaration has no usable return type.
        args: the parenthesized argument list; ``ast()`` requires it to be
            a ``ComplexType`` whose ``name`` is ``'Tuple'``.
        func_name: the (possibly dotted) function name as written.
        func_parts: ``func_name`` split on ``"."``.
    """

    def __init__(self, returns, args, func_name, func_parts):
        # Explicit assignments rather than ``self.__dict__.update(locals())``:
        # that idiom also stored ``self`` on the instance, creating a
        # pointless ``self.self`` reference cycle.
        self.returns = returns
        self.args = args
        self.func_name = func_name
        self.func_parts = func_parts

    def ast(self):
        """Return an ``Assign`` node binding ``__types__`` to a ``Dict``.

        The dict holds one ``(Const(arg_name), type_ast)`` entry per typed
        argument, plus a ``_RETURNS`` entry when a return type is present.
        Arguments without a type are skipped.
        """
        assert isinstance(self.args, ComplexType)
        assert self.args.name == 'Tuple'
        # Build a real list (not the result of filter()) so that the
        # _RETURNS entry can always be appended.
        keys = []
        for arg in self.args.parts:
            entry = arg.ast()
            if entry:
                keys.append(entry)
        # Build the return-type tree once and reuse it.
        returns_ast = None
        if self.returns:
            returns_ast = self.returns.ast()
        if returns_ast:
            keys.append((Const("_RETURNS"), returns_ast))
        return Assign([AssName('__types__', 'OP_ASSIGN')], Dict(keys))

class Role:
    """A named slot in a sequence (an argument), optionally typed.

    Attributes:
        arg_name: the slot's name.
        type: type object for the slot, or None when no type was given.
        special: extra marker supplied by the caller (the parser passes the
            number of leading asterisks).
    """

    def __init__(self, arg_name, type, special):
        # Explicit assignments rather than ``self.__dict__.update(locals())``,
        # which also stored ``self`` on the instance (``self.self``).
        self.arg_name = arg_name
        self.type = type
        self.special = special

    def ast(self):
        """Return ``(Const(arg_name), type_ast)``, or None if untyped."""
        if not self.type:
            return None
        return (Const(self.arg_name), self.type.ast())

class Type:
    """A (possibly dotted, possibly parameterized) type reference.

    Attributes:
        type_name: the dotted name as written, e.g. ``"pkg.mod.Name"``.
        type_parts: ``type_name`` split on ``"."``; an empty list when
            ``type_name`` is empty or None.
        args: list of argument objects providing ``ast()``, or None when
            the type takes no arguments.
    """

    def __init__(self, type_name, args):
        self.type_name = type_name
        # Always define type_parts so the attribute exists even when no
        # name was supplied.
        if type_name:
            self.type_parts = type_name.split(".")
        else:
            self.type_parts = []
        self.args = args

    def ast(self):
        """Return the expression tree that evaluates to this type.

        A dotted name becomes nested ``Getattr`` nodes rooted at a ``Name``;
        when arguments are present the result is wrapped in a ``CallFunc``.
        """
        typecode = Name(self.type_parts[0])
        for part in self.type_parts[1:]:
            # Chain attribute access onto the tree built so far (the
            # original referenced an undefined name ``code`` here).
            typecode = Getattr(typecode, part)
        if self.args:
            return CallFunc(typecode, [arg.ast() for arg in self.args],
                            None, None)
        return typecode

class ComplexType(Type):
    """Abstract base for composite types; subclasses define their own __init__."""

    def __init__(self):
        # Raised explicitly (instead of ``assert 0``) so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError("Abstract type instantiatied")

class UnionType(ComplexType):
    """A choice between several alternative types.

    Attributes:
        name: name of the callable used in the generated call.
        parts: non-empty sequence of ``Type`` instances (the alternatives).
    """

    def __init__(self, name, parts):
        self.name = name
        self.parts = parts
        assert parts
        for part in parts:
            # The original ``assert part, isinstance(part, Type)`` used the
            # isinstance() result as the assertion *message*, so the type
            # was never actually checked.
            assert isinstance(part, Type), \
                "Union members must be Type instances, got %r" % (part,)

    def ast(self):
        """Return ``CallFunc(Name(name), [alternative asts...])``."""
        parts = [part.ast() for part in self.parts if part]
        return CallFunc(Name(self.name), parts, None, None)

class SequenceType(ComplexType):
    """An ordered group of roles.

    Attributes:
        name: name of the callable used in the generated call.
        parts: non-empty sequence of ``Role`` instances.
    """

    def __init__(self, name, parts):
        self.name = name
        self.parts = parts
        assert parts
        for part in parts:
            # The original ``assert part, isinstance(part, Role)`` used the
            # isinstance() result as the assertion *message*, so the type
            # was never actually checked.
            assert isinstance(part, Role), \
                "Sequence members must be Role instances, got %r" % (part,)

    def ast(self):
        """Return ``CallFunc(Name(name), [type asts of typed roles...])``.

        Roles without a type are left out of the call.
        """
        parts = [part.type.ast() for part in self.parts if part.type]
        return CallFunc(Name(self.name), parts, None, None)

class Arg:
    """Positional argument of a parameterized type; wraps a single ``Type``."""

    def __init__(self, typ):
        self.typ = typ
        assert isinstance(typ, Type)

    def ast(self):
        """Return the tree of the wrapped type unchanged."""
        wrapped = self.typ
        return wrapped.ast()

class KwArg(Arg):
    """Keyword argument of a parameterized type: a keyword plus a ``Type``."""

    def __init__(self, keyword, typ):
        self.keyword = keyword
        self.typ = typ
        assert isinstance(typ, Type)

    def ast(self):
        """Return a ``Keyword`` node pairing the keyword with the type's tree."""
        value = self.typ.ast()
        return Keyword(self.keyword, value)

# Dotted identifier such as "pkg.mod.name"; whitespace is permitted around
# each dot.
dotted_name_re = re.compile(r"\w+(\s*\.\s*\w+)*")
# Possibly-empty run of whitespace; this pattern always matches.
ws_re = re.compile(r"\s*")
# Single undotted identifier. In the visible source it is referenced only
# from the commented-out keyword-argument code in parse_type_arg.
kw_name_re = re.compile(r"\w+")

def reportError(data, pos, expected):
    """Build (but do not raise) a SyntaxError describing a parse failure.

    The message shows the line of ``data`` that contains offset ``pos``,
    a caret under the offending column, and what was expected.

    Args:
        data: the full text being parsed.
        pos: offset into ``data`` where parsing failed.
        expected: human-readable description of the expected input.

    Returns:
        A ``SyntaxError`` instance; the caller is responsible for raising it.
    """
    # rfind() returns -1 when pos is on the first line, so +1 yields 0 there
    # and the index just past the newline otherwise.
    line_start = data.rfind("\n", 0, pos) + 1
    line_end = data.find("\n", pos)
    if line_end < 0:
        # Last line without a trailing newline: take everything to the end
        # instead of slicing with -1 (which dropped the final character).
        line_end = len(data)
    line = data[line_start:line_end]

    val = "Parse error on this line:\n"
    val += line + "\n"
    val += " " * (pos - line_start) + "^\n"
    val += "Expected %s\n" % expected
    return SyntaxError(val)

def parseString(data, pos, *expects_list):
    """Consume one of the literal strings in ``expects_list`` at ``pos``.

    Whitespace before and after the literal is skipped. Candidates are
    tried in the order given and the first one that matches wins.

    Returns:
        ``(matched_literal, new_pos)``.

    Raises:
        SyntaxError: if none of the literals is found at ``pos``.
    """
    start = pos + stripwslen(data, pos)
    for literal in expects_list:
        stop = start + len(literal)
        if data[start:stop] != literal:
            continue
        return literal, stop + stripwslen(data, stop)
    raise reportError(data, start, " or ".join(expects_list))

def parseExpression(data, pos, re, name):
    """Consume text matching the compiled pattern ``re`` at ``pos``.

    Whitespace before and after the match is skipped. Note that the ``re``
    parameter is a compiled pattern object and shadows the ``re`` module
    inside this function.

    Args:
        data: the text being parsed.
        pos: offset at which to start.
        re: compiled regular expression to match.
        name: description of the construct, used in the error message.

    Returns:
        ``(matched_text, new_pos)``.

    Raises:
        SyntaxError: if the pattern does not match at ``pos``.
    """
    start = pos + stripwslen(data, pos)
    found = re.match(data, start)
    if found:
        text = found.group()
        after = start + len(text)
        return text, after + stripwslen(data, after)
    raise reportError(data, start, "'%s' (%s)" % (name, re.pattern))

def stripwslen(data, pos):
    """Return the length of the whitespace run starting at ``data[pos]``."""
    leading = ws_re.match(data, pos)
    return len(leading.group())

def parse_declarations(data, pos=0):
    """Parse every declaration in ``data`` and return them as a list.

    ``#`` comments are removed before parsing. Input that contains nothing
    but whitespace and comments yields an empty list.

    Args:
        data: source text holding zero or more ``def`` declarations.
        pos: offset at which to start parsing.

    Returns:
        A list with one object per declaration, in source order.

    Raises:
        SyntaxError: if the text does not follow the declaration grammar.
    """
    # Drop each comment up to (not including) its newline. Unlike the old
    # "#.*\n" pattern this also removes a comment on the final line when
    # the text does not end with a newline.
    data = re.sub(r"#[^\n]*", "", data)

    # Each parsed declaration consumes the whitespace that follows it, so
    # trailing whitespace only matters when nothing is left to parse;
    # bounding the loop by the stripped length handles that case.
    end = len(data.rstrip())

    declarations = []
    while pos < end:
        declaration, pos = parse_declaration(data, pos)
        declarations.append(declaration)
    return declarations

def parse_declaration(data, pos):
    """Parse one ``def name(args) [-> returns]`` declaration at ``pos``.

    Returns:
        ``(Declaration, new_pos)``.

    Raises:
        SyntaxError: if the text at ``pos`` is not a declaration.
    """
    _, pos = parseString(data, pos, "def")
    func_name, pos = parseExpression(data, pos, dotted_name_re, "Type name")

    # Validation only: the argument list must open with a round paren
    # (square brackets are rejected here). The returned position is
    # deliberately discarded so the pointer does NOT advance.
    parseString(data, pos, "(")

    args, pos = parse_parenthesized_type(data, pos)

    returns = None
    if data.startswith("->", pos):
        _, pos = parseString(data, pos, "->")
        if data[pos] == "(":
            returns, pos = parse_parenthesized_type(data, pos)
        else:
            role, pos = parse_role_and_type(data, pos)
            returns = role.type  # None when the role carries no type

    func_parts = func_name.split(".")
    return Declaration(returns, args, func_name, func_parts), pos

# Maps each opening bracket to its matching closing bracket.
inverse={"{":"}", "[":"]", "(": ")", "<": ">"}
# The opening brackets themselves (the keys of ``inverse``).
start_brackets = inverse.keys()

def _first_separator(data, pos):
    """Return the first top-level ``|``, ``:`` or ``,`` at or after ``pos``.

    ``pos`` must lie just inside an opening bracket. Nested ``(...)`` and
    ``[...]`` groups are skipped, and the scan stops at the bracket that
    closes the enclosing group. Returns None when that group contains no
    top-level separator.
    """
    depth = 0
    for index in range(pos, len(data)):
        char = data[index]
        if char in "([":
            depth += 1
        elif char in ")]":
            if depth == 0:
                # Reached the close of the enclosing group.
                return None
            depth -= 1
        elif depth == 0 and char in "|:,":
            return char
    return None


def parse_parenthesized_type(data, pos):
    """Parse a bracketed group at ``pos`` as either a choice or a sequence.

    A round-paren group whose first top-level separator is ``|`` is parsed
    as a choice; everything else (including any square-bracket group) is
    parsed as a sequence.

    Returns:
        ``(type_object, new_pos)`` as produced by ``parse_choice`` or
        ``parse_sequence``.

    Raises:
        SyntaxError: if ``pos`` is not at ``(`` or ``[``, or the contents
            are malformed.
    """
    # Only peek at the bracket: the chosen sub-parser consumes it itself,
    # so ``pos`` is passed on unchanged.
    bracket, inner = parseString(data, pos, "(", "[")

    # Nesting-aware lookahead limited to this group. The previous version
    # searched the whole remaining text, so it could be fooled by nested
    # groups or by later declarations, and it raised IndexError when no
    # separator followed at all.
    if bracket == "(" and _first_separator(data, inner) == "|":
        return parse_choice(data, pos)
    return parse_sequence(data, pos)

def parse_choice(data, pos):
    """Parse ``(type | type | ...)`` at ``pos`` into a ``UnionType``.

    Returns:
        ``(UnionType("Union", alternatives), new_pos)``.

    Raises:
        SyntaxError: on malformed input.
        AssertionError: if fewer than two alternatives were given.
    """
    _, pos = parseString(data, pos, "(")
    alternatives = []
    more = True
    while more:
        alternative, pos = parse_type(data, pos)
        alternatives.append(alternative)
        # A closing paren ends the list; anything else must be a "|".
        more = data[pos] != ")"
        if more:
            _, pos = parseString(data, pos, "|")

    _, pos = parseString(data, pos, ")")
    assert len(alternatives)>1
    return UnionType("Union", alternatives), pos

def parse_sequence(data, pos):
    """Parse a comma-separated list of roles enclosed in ``()`` or ``[]``.

    Returns:
        ``(SequenceType, new_pos)``; the sequence is named ``"Tuple"`` for
        round parens and ``"Sequence"`` for square brackets.

    Raises:
        SyntaxError: on malformed input.
        AssertionError: if an ellipsis (``...``) is encountered, which is
            not supported yet.
    """
    bracket, pos = parseString(data, pos, "(", "[")
    closer = inverse[bracket]
    roles = []
    while 1:
        # look for ellipse
        if data[pos] == ".":
            _, pos = parseString(data, pos, "...")
            # Raised explicitly rather than via ``assert 0`` so that running
            # with -O cannot fall through and append an unbound/stale value.
            raise AssertionError("... not handled yet")

        role, pos = parse_role_and_type(data, pos)
        roles.append(role)

        if data[pos] == closer:
            break
        _, pos = parseString(data, pos, ",")

    _, pos = parseString(data, pos, closer)

    if bracket == "(":
        return SequenceType("Tuple", roles), pos
    else:
        return SequenceType("Sequence", roles), pos

def parse_type(data, pos):
    """Parse a type at ``pos``.

    A type is either a bracketed group (choice or sequence) or a dotted
    name optionally followed by a parenthesized argument list.

    Returns:
        ``(type_object, new_pos)``.

    Raises:
        SyntaxError: if no type is found at ``pos``, including when ``pos``
            is at the end of the input.
    """
    # Bounds check first: at end of input, fall through to
    # parseExpression so the caller gets a SyntaxError rather than an
    # IndexError from data[pos].
    if pos < len(data) and data[pos] in "([":
        typ, pos = parse_parenthesized_type(data, pos)
        assert isinstance(typ, Type)
    else:
        type_name, pos = parseExpression(data, pos, dotted_name_re,
                                        "Type name")
        if pos < len(data) and data[pos] == "(":
            args, pos = parse_type_args(data, pos)
        else:
            args = None
        typ = Type(type_name, args)
    return typ, pos
    
def parse_type_args(data, pos):
    """Parse ``(arg, arg, ...)`` following a type name.

    Returns:
        ``(list_of_args, new_pos)`` with ``new_pos`` past the closing paren.

    Raises:
        SyntaxError: on malformed input.
    """
    _, pos = parseString(data, pos, "(")
    collected = []
    while True:
        arg, pos = parse_type_arg(data, pos)
        collected.append(arg)
        assert arg is not None
        if data[pos] != ")":
            # More arguments follow; they must be comma-separated.
            _, pos = parseString(data, pos, ",")
            continue
        _, pos = parseString(data, pos, ")")
        return collected, pos
        
def parse_type_arg(data, pos):
    """Parse one positional type argument and wrap it in an ``Arg``.

    Returns:
        ``(Arg, new_pos)``.
    """
    # ignore keyword args for now...
    #keyword, pos = parseExpression(data, pos, kw_name_re, "Keyword name")
    #_, pos = parseString(data, pos, "=")
    parsed = parse_type(data, pos)
    inner_type, next_pos = parsed
    return Arg(inner_type), next_pos

def parse_role_and_type(data, pos):
    """Parse ``[*...]name[: type]`` at ``pos`` into a ``Role``.

    Leading asterisks are counted and passed to the ``Role`` as its
    ``special`` value; the type is None when no ``:`` follows the name.

    Returns:
        ``(Role, new_pos)``.

    Raises:
        SyntaxError: if no name is found.
        AssertionError: if a typed argument name contains a dot.
    """
    stars_begin = pos
    while data[pos] == "*":  # ignore asterisks for now
        pos += 1
    asterisks = pos - stars_begin

    role_name, pos = parseExpression(data, pos, dotted_name_re, "Arg name")

    typ = None
    if data[pos] == ":":
        assert "." not in role_name, "Argument names cannot have dots"
        _, pos = parseString(data, pos, ":")
        typ, pos = parse_type(data, pos)

    return Role(role_name, typ, asterisks), pos

# Sanity check: executed every time this module is imported. The parsed
# result is discarded; the call only verifies that these sample
# declarations (dotted names, comments, a return type, a nested
# square-bracket sequence) still parse without raising.
parse_declarations("""
def funcname2.funcname2.funcname (arg2:type.type.type,arg) # nother test
def blah( arg : IString, pos: IInteger  ) -> j: IInteger # test
def funcname2.funcname2.funcname (arg3:[word : abc],arg) # nother test
def _Environ.__setitem__(self, key: IString, item: IString)
""")

