#! /usr/bin/python
# -*- coding: utf-8 -*-

import re, os, sys
from subprocess import Popen, PIPE
from optparse import OptionParser


# Permissive C numeric-literal pattern with one capture group: optional
# sign, digits, decimal point, exponent marker, exponent sign/digits and a
# trailing "f" are each optional, so it also matches the empty string.
num_re = r"(-?\d*\.?\d*e?[+-]?\d*f?)"

class UIDefs(object):
    """Collection of UI parameter definitions, keyed by C++ variable name.

    Entries are filled in attribute by attribute through ``add`` and later
    emitted as C++ registration calls by ``write``.
    """

    class UID(dict):
        """Attributes of one UI variable; missing keys read as ""."""

        def write(self, fp, v):
            """Write the registration code for this variable to ``fp``.

            Variables carrying an "alias" key produce no output.  If an
            "enum" key (values separated by "|") is present, a static
            NULL-terminated string array is emitted and the variable is
            registered with registerEnumVar; otherwise registerVar is used.
            ``v`` is unused (the name comes from the "variable" key); it is
            kept for interface compatibility.
            """
            if "alias" in self:
                return
            # "value" holds the widget's numeric arguments (presumably
            # init/min/max/step); they are appended after &variable.
            self["tail"] = "".join([", " + x for x in self["value"]])
            if "enum" in self:
                enumvals = ",".join(['"%s"' % x for x in self["enum"].split("|")] + ["0"])
                self["ename"] = ename = self["variable"] + "_values"
                fp.write('\tstatic const char *%s[] = {%s};\n' % (ename, enumvals))
                fp.write('\tregisterEnumVar("%(id)s","%(name)s","%(type)s","%(tooltip)s",%(ename)s,&%(variable)s%(tail)s);\n' % self)
            else:
                fp.write('\tregisterVar("%(id)s","%(name)s","%(type)s","%(tooltip)s",&%(variable)s%(tail)s);\n' % self)

        def __getitem__(self, n):
            # "%(key)s" % self goes through this, so absent attributes
            # format as empty strings instead of raising KeyError
            try:
                return dict.__getitem__(self, n)
            except KeyError:
                return ""

    def __init__(self):
        self.ui = {}  # variable name -> UID

    def add(self, element, key, value):
        """Set attribute ``key`` of variable ``element``, creating its entry."""
        try:
            uid = self.ui[element]
        except KeyError:
            self.ui[element] = uid = self.UID(variable=element)
        uid[key] = value

    def get(self, element, key):
        """Return attribute ``key`` of ``element`` ("" if unset).

        Raises KeyError if ``element`` was never added.
        """
        return self.ui[element][key]

    def has(self, element, key):
        """Return True if attribute ``key`` was set for ``element``."""
        # "in" instead of dict.has_key, which no longer exists in Python 3
        return key in self.ui[element]

    def var_filter(self):
        """Return a predicate that is truthy for "<uivar> = <number>;" lines.

        Only known UI variable names are matched.  When there are no UI
        variables, the predicate matches nothing.  An empty alternation
        would otherwise produce the pattern "ws ( ) ;", which matches any
        line consisting of a bare ";".
        """
        if not self.ui:
            return lambda line: False
        s = "|".join([r"%s\s*=\s*%s" % (v, num_re) for v in self.ui.keys()])
        return re.compile(r"\s*(" + s + ");").match

    def write(self, fp):
        """Write the registration code of all variables to ``fp``."""
        for v, r in self.ui.items():
            r.write(fp, v)

class Parser(object):
    """Extract the parts of faust-generated C++ that are re-emitted.

    ``lines`` is consumed strictly front to back (it may be a pipe), so the
    order of the steps in ``__init__`` follows the order in which faust
    writes the class members.  Results:

    - ``self.sections``: dict with the line lists "var-decl", "var-init",
      "var-free", "compute" and "ui" (always empty; the UI is in
      ``self.ui``)
    - ``self.ui``: UIDefs collected from ``buildUserInterface``
    - ``self.meta`` / ``self.toplevel``: metadata from ``metadata()``
    - ``self.numInputs`` / ``self.numOutputs``
    - ``self.memlist``: (name, ctype, length) of arrays whose static
      declaration was replaced by a pointer
    """

    def skip_until(self, exp):
        """Consume lines until one matches ``exp`` and return that match.

        Returns None when the input ends without a match.
        """
        r = re.compile(exp)
        for line in self.lines:
            m = r.match(line)
            if m:
                return m
        return None

    def skip_while(self, exp):
        """Consume lines while they match ``exp``.

        Returns the first non-matching line.  That line has already been
        consumed, so the caller must pass it on.  Returns "" at end of
        input.
        """
        m = re.compile(exp).match
        for line in self.lines:
            if not m(line):
                return line
        return ""

    def copy(self, exp, line=None):
        """Collect lines up to, but excluding, the first one matching ``exp``.

        ``line``, if non-empty, is put in front (a line already consumed by
        ``skip_while``).  The matching line itself is consumed and dropped.
        The common leading-tab indentation (at most 10 tabs) is removed.
        Blank lines are kept as a bare newline.
        """
        cp = []
        if line:
            cp.append(line)
        stop = re.compile(exp).match
        for l in self.lines:
            if stop(l):
                break
            cp.append(l)
        # remove the common leading-tab indentation
        tabs = re.compile(r"\t*").match
        n = 10
        for l in cp:
            if l != "\n":
                n = min(n, len(tabs(l).group(0)))
        # blank lines don't count for the minimum; slicing them would also
        # cut off their newline and glue indentation onto the next line
        return [l if l == "\n" else l[n:] for l in cp]

    def get_section_list(self):
        """Return the keys of ``self.sections``."""
        return "var-decl", "var-init", "var-free", "ui", "compute"

    def getIO(self, s):
        """Return N from ``virtual int getNum<s>puts() { return N; }``.

        ``s`` is "In" or "Out".  Lines are consumed up to and including the
        matching one.  Raises ValueError if the input ends first.
        """
        e = re.compile(r"\s*virtual int getNum%sputs\(\)\s*{\s*return\s*(\d+);\s*}" % s)
        for line in self.lines:
            m = e.match(line)
            if m:
                return int(m.group(1))
        raise ValueError("getNum%sputs not found in source" % s)

    def readUI(self, exp):
        """Parse the body of faust's buildUserInterface into ``self.ui``.

        Consumes lines until one matches ``exp``.
        - Open/close box calls maintain a stack of group names that prefix
          the dotted parameter ids.
        - Sliders and num-entries become type "S"; check buttons become
          type "B".
        - ``interface->declare(&var, key, value)`` stores key/value on the
          variable.
        - The first top-level box name becomes ``self.topname``. If it
          equals the module name, ``self.toplevel`` is used instead.

        Returns None.  Raises AssertionError on any unrecognized line.
        """
        stop = re.compile(exp).match
        pre = r"\s*interface->"
        nm = '"([^"]*)"'
        vr = "([a-zA-Z_][0-9a-zA-Z_]*)"
        sarg = (r"%s,\s*&%s,\s*%s,\s*%s,\s*%s,\s*%s"
                % (nm, vr, num_re, num_re, num_re, num_re))
        openbox = re.compile(pre+r"open(Horizontal|Vertical)Box\(%s\);" % nm).match
        closebox = re.compile(pre+r"closeBox\(\);").match
        vslider = re.compile(pre+r"addVerticalSlider\(%s\);" % sarg).match
        hslider = re.compile(pre+r"addHorizontalSlider\(%s\);" % sarg).match
        numentry = re.compile(pre+r"addNumEntry\(%s\);" % sarg).match
        checkbutton = re.compile(pre+r"addCheckButton\(%s,\s*&%s\);" % (nm, vr)).match
        declare = re.compile(pre+r"declare\(&%s,\s*%s,\s*%s\);" % (vr, nm, nm)).match
        stack = []

        def make_name(name):
            # a leading "." marks an absolute id: strip it, ignore the boxes
            if name.startswith("."):
                return name[1:]
            return ".".join(stack + [name])

        for line in self.lines:
            if stop(line):
                return
            m = openbox(line)
            if m:
                grp = m.group(2)
                if not stack:
                    if self.toplevel and grp == self.modname:
                        grp = self.toplevel
                    if not self.topname:
                        self.topname = grp
                stack.append(grp)
                continue
            if closebox(line):
                stack.pop()
                continue
            m = vslider(line) or hslider(line) or numentry(line)
            if m:
                vn = m.group(2)
                self.ui.add(vn, "type", "S")
                self.ui.add(vn, "id", make_name(m.group(1)))
                self.ui.add(vn, "value", m.groups()[2:])
                continue
            m = checkbutton(line)
            if m:
                vn = m.group(2)
                self.ui.add(vn, "type", "B")
                self.ui.add(vn, "id", make_name(m.group(1)))
                self.ui.add(vn, "value", ("0.0","0.0","1.0","1.0"))
                continue
            m = declare(line)
            if m:
                self.ui.add(m.group(1), m.group(2), m.group(3))
                continue
            # explicit raise: a bare assert would vanish under "python -O"
            # and silently skip the line.  Deliberately not ValueError:
            # main() handles that by waiting on faust, whose remaining
            # output has not been read yet.
            raise AssertionError(line)

    def readMeta(self):
        """Read faust 9.4 style "// key: "value"" comment metadata.

        Only needed for faust 9.4; not used at the moment.  A "name" entry
        also sets ``self.toplevel``.
        """
        self.meta = {}
        stop = re.compile(r'// Code generated with Faust').match
        declare = re.compile(r'// ([^:]+):\s*"([^"]*)"\s*$').match
        for line in self.lines:
            if stop(line):
                return
            m = declare(line)
            if m:
                key = m.group(1)
                value = m.group(2)
                self.meta[key] = value
                if key == "name":
                    self.toplevel = value

    def readMeta2(self, stop_expr):
        """Read ``m->declare("key", "value");`` lines until ``stop_expr``.

        The metadata is stored in ``self.meta``.  A "name" entry also sets
        ``self.toplevel``.
        """
        self.meta = {}
        stop = re.compile(stop_expr).match
        declare = re.compile(r'\s*m->declare\s*\("([^"]+)"\s*,\s*"([^"]*)"\);').match
        for line in self.lines:
            if stop(line):
                return
            m = declare(line)
            if m:
                key = m.group(1)
                value = m.group(2)
                self.meta[key] = value
                if key == "name":
                    self.toplevel = value

    def change_var_decl(self, lines):
        """Rewrite the member declarations in ``lines`` for the output.

        - A ``FAUSTFLOAT var;`` declaration of a variable with a UI "alias"
          becomes a reference initialised with ``get_alias("<id>")``.
        - If ``memory_threshold`` is set, int/float/double arrays larger
          than the threshold (in bytes, assuming 4/4/8-byte elements)
          become pointers and are recorded in ``self.memlist``.
        """
        param_matcher = re.compile(r"FAUSTFLOAT\s+([a-zA-Z_0-9]+);\n$").match
        array_matcher = re.compile(r"(int|float|double)\s+([a-zA-Z_0-9]+)\s*\[\s*(\d+)\s*\]\s*;\n$").match
        out = []
        for l in lines:
            m = param_matcher(l)
            if m:
                var = m.group(1)
                if self.ui.has(var, "alias"):
                    l = ('FAUSTFLOAT&\t%s = get_alias("%s");\n'
                         % (var, self.ui.get(var, "id")))
            if self.memory_threshold:
                m = array_matcher(l)
                if m:
                    sz = {"int": 4, "float": 4, "double": 8}[m.group(1)]
                    alen = int(m.group(3))
                    if alen * sz > self.memory_threshold:
                        l = "%s *%s;\n" % (m.group(1), m.group(2))
                        self.memlist.append((m.group(2), m.group(1), alen))
            out.append(l)
        return out

    def add_var_alloc(self):
        """Return C++ lines allocating each array in ``self.memlist``."""
        return ["if (!%s) %s = new %s[%d];\n" % (v, v, t, s)
                for v, t, s in self.memlist]

    def add_var_free(self):
        """Return C++ lines releasing each array in ``self.memlist``."""
        # arrays come from new T[n] (add_var_alloc), so they must be
        # released with delete[]; plain delete is undefined behaviour
        return ["if (%s) { delete[] %s; %s = 0; }\n" % (v, v, v)
                for v, t, s in self.memlist]

    def __init__(self, lines, modname, memory_threshold):
        """Parse faust output ``lines`` for module ``modname``.

        A non-zero ``memory_threshold`` moves member arrays larger than this
        many bytes to dynamic allocation; 0 disables this.  Raises ValueError
        if getNumInputs/getNumOutputs is missing.  Raises AssertionError on
        an unknown buildUserInterface line.
        """
        self.lines = lines
        self.modname = modname
        self.memory_threshold = memory_threshold
        self.toplevel = None
        self.topname = None
        self.memlist = []
        s = {}
        self.ui = UIDefs()
        # self.readMeta() would be needed here for faust 9.4 only
        self.skip_until(r"  private:")
        var_decl = self.copy(r"  public:")
        self.skip_until(r"^\s*static\s+void\s+metadata\s*\(\s*Meta\s*\*\s*m\s*\)\s*{")
        self.readMeta2(r"\s*}\s*\n$")
        self.numInputs = self.getIO("In")
        self.numOutputs = self.getIO("Out")
        self.skip_until(r"\s*static void classInit")
        s["var-init"] = self.copy(r"\s*}$")
        self.skip_until(r"\s*virtual void instanceInit")
        s["var-init"] += self.copy(r"\s*}$")
        self.skip_until(r"\s*virtual void buildUserInterface")
        self.readUI(r"\s*}$")
        s["ui"] = []  # readUI stores its results in self.ui, not as lines
        # needs the "alias" attributes collected by readUI
        s["var-decl"] = self.change_var_decl(var_decl)
        # allocate dynamic arrays before any of faust's init code runs
        s["var-init"] = self.add_var_alloc() + s["var-init"]
        s["var-free"] = self.add_var_free()
        self.skip_until(r"\s*virtual void compute")
        # the "input0 = input[0];" style pointer definitions are dropped
        iodef = r"\s*(float|FAUSTFLOAT)\s*\*\s*(in|out)put(\d+)\s*=\s*\2put\[\3\];"
        s["compute"] = self.copy(iodef)
        line = self.skip_while(iodef)
        s["compute"] += self.copy(r"\s*}$", line)
        self.sections = s
        if self.topname is None:
            self.topname = self.modname
        # any following definitions of static class members are ignored

    def getNumInputs(self):
        return self.numInputs

    def getNumOutputs(self):
        return self.numOutputs

    def __getitem__(self, n):
        return self.sections[n]

    def write(self, fp, sect, indent=0, filt=lambda l: False):
        """Write section ``sect`` to ``fp`` with ``indent`` leading tabs.

        Lines for which ``filt`` returns a true value are skipped.
        """
        pre = "\t" * indent
        for l in self.sections[sect]:
            if filt(l):
                continue
            fp.write(pre)
            fp.write(l)

# C++ source that output() writes verbatim when some arrays were made
# dynamic: activate(true, sr) calls init(sr) unless is_inited() is already
# true; activate(false, ...) calls mem_free() if is_inited().
activate = """
void activate(bool start, int samplingFreq)
{
	if (start) {
		if (!is_inited()) {
			init(samplingFreq);
		}
	} else {
		if (is_inited()) {
			mem_free();
		}
	}
}

"""

def output(fp, p, fname):
    """Write the C++ translation unit for parser ``p`` to ``fp``.

    Everything is wrapped in ``namespace <p.modname>``.  When ``p.memlist``
    is non-empty (some arrays are allocated dynamically), the following are
    emitted as well:
    - an ``inited`` flag,
    - ``mem_free()`` and ``is_inited()``,
    - the ``activate`` template.
    ``fname`` only appears in a comment.
    """
    put = fp.write
    dynamic_mem = len(p.memlist) != 0
    ns = p.modname

    put("namespace %s {\n" % ns)
    put("// generated from file '%s'\n\n" % fname)
    if dynamic_mem:
        put("volatile bool inited = false;\n")
    p.write(fp, "var-decl")
    put("int\tfSamplingFreq;\n\n")

    # init(): faust's init code, filtered through p.ui.var_filter()
    put("void init(int samplingFreq)\n{\n")
    p.write(fp, "var-init", 1, filt=p.ui.var_filter())
    if dynamic_mem:
        put("\tinited = true;\n")
    put("}\n\n")

    if dynamic_mem:
        # the flag is cleared and jack_sync() called before the frees run
        put("void mem_free()\n{\n\tinited = false;\n\tjack_sync();\n")
        p.write(fp, "var-free", 1)
        put("}\n\ninline bool is_inited()\n{\n    return inited;\n}\n\n")
        put(activate)

    in_args = "".join(", float *input%d" % k for k in range(p.getNumInputs()))
    out_args = "".join(", float *output%d" % k for k in range(p.getNumOutputs()))
    put("void compute(int count%s%s)\n{\n" % (in_args, out_args))
    p.write(fp, "compute", 1)
    # closes the block left open by the copied compute section
    # (presumably faust's sample loop), then the function itself
    put("\t}\n}\n\n")

    put("static struct RegisterParams { RegisterParams(); } RegisterParams;\n"
        "RegisterParams::RegisterParams()\n{\n")
    p.ui.write(fp)
    put('\tregisterInit("%s", init);\n' % p.topname)
    put("}\n")
    put("\n} // end namespace %s\n" % ns)

def main():
    """Command-line entry point: run faust on a .dsp file and emit C++.

    Runs ``faust <file> [-double|-single]`` and parses its output with
    Parser.  The converted code is written to the ``-o`` file or to stdout.
    Exits with status 1 in any of these cases:
    - the input file is missing,
    - faust cannot be started,
    - Parser raises ValueError.
    """
    op = OptionParser(usage="usage: %prog [options] <faust-dsp-file>")
    op.add_option("-o", "--output", dest="oname",
                  help="write c++ code to FILE", metavar="FILE")
    # NOTE(review): -d/-f take an argument, but only their presence is used
    # to pick the precision flag; the values are never passed to faust.
    op.add_option("-d", "--double", dest="faust", action="append",
                  help="additional faust options, build with double precision")
    op.add_option("-f", "--float", dest="faustf", action="append",
                  help="additional faust options, build with single precision")
    op.add_option("-s", "--memory-threshold", dest="memory_threshold",
                  default=0, type="int",
                  help="change static memory allocations above threshold to dynamic ones")
    options, args = op.parse_args()
    if len(args) != 1:
        op.error("exactly one input filename expected\n")
    fname = args[0]
    if not os.path.exists(fname):
        print("error: can't open '%s'" % fname)
        raise SystemExit(1)
    if options.faust:
        precision = ["-double"]
    elif options.faustf:
        precision = ["-single"]
    else:
        precision = []
    # Pass an argument list instead of a shell string, so file names with
    # spaces or shell metacharacters reach faust unchanged.  Text mode makes
    # the Parser see str lines on both Python 2 and 3.
    try:
        faust = Popen(["faust", fname] + precision, stdout=PIPE,
                      universal_newlines=True)
    except OSError as e:
        print("error: can't run faust: %s" % e)
        raise SystemExit(1)
    try:
        parser = Parser(faust.stdout,
                        os.path.splitext(os.path.basename(fname))[0],
                        options.memory_threshold)
    except ValueError as e:
        # stop reading first so wait() cannot hang on a faust process
        # blocked writing unread output into the pipe
        faust.stdout.close()
        if faust.wait() == 0:
            print(e)
        raise SystemExit(1)
    if options.oname:
        with open(options.oname, "w") as outp:
            output(outp, parser, fname)
    else:
        output(sys.stdout, parser, fname)

# Run the converter only when executed as a script, not when imported.
if "__main__" == __name__:
    main()
