from .factor import *
from .utils import repr_elems,trim,product
import itertools

# Public API of this module: the C code generator, its cost estimator and
# the reference counter used to schedule free() calls.
__all__ = [
    'cgen',
    'cgenstats',
    'countrefs',
]

def countrefs(tree):
    '''Count how often each node of ``tree`` is referenced by a parent.

    Nodes are numbered in post-order (a node gets its id only after all of
    its children), the same numbering ``cgen`` assigns.  Returns a list
    ``nrefs`` where ``nrefs[i]`` is the number of times node ``i`` appears
    among the children of some branch; a node that is nobody's child (the
    root) has count 0.
    '''
    nrefs = []
    ids = itertools.count()

    def leaf(tree):
        nrefs.append(0)
        # next() builtin works on Python 2.6+ and 3 (ids.next() is 2-only)
        return next(ids)

    def branch(tree,reschildren):
        # reschildren are the ids returned for this branch's children
        for childid in reschildren:
            nrefs[childid] += 1
        nrefs.append(0)
        return next(ids)

    folddetfac(tree,branch,leaf)
    return nrefs

def arr(nodeid):
    '''Name of the C array variable holding node ``nodeid``'s result (e.g. a007).'''
    return 'a' + format(nodeid, '03d')
def fun(nodeid):
    '''Name of the C function that computes node ``nodeid`` (e.g. f007).'''
    return 'f' + format(nodeid, '03d')
#def cumprod(xs):
#    prod = 1
#    for x in xs:
#        prod *= x
#        yield prod

def cgenstats(tree):
    '''Count the multiplications, additions and lookups performed by a plan.

    Only plans built from Multinom leaves and SumProd branches are
    supported.  For each SumProd node the counts are derived from
    ``nadd = len(I().product(*tree.fac))`` (presumably the size of the
    factors' joint domain): ``nadd`` additions, ``nadd * (nfac - 1)``
    multiplications and ``nadd * nfac`` lookups.  Results routed through
    ``shared`` (presumably repeated sub-plans) contribute nothing.

    Returns a tuple ``(nmult, nadd, nlook)``.
    Raises TypeError for an unsupported leaf or branch type (an explicit
    raise, unlike ``assert``, survives ``python -O``).
    '''
    def leaf(tree):
        if not isinstance(tree,Multinom):
            raise TypeError(
                'cgenstats: unsupported leaf type {0}'.format(type(tree).__name__))
        return 0,0,0

    def shared(res):
        return 0,0,0

    def branch(tree,reschildren):
        if not isinstance(tree,SumProd):
            raise TypeError(
                'cgenstats: unsupported branch type {0}'.format(type(tree).__name__))
        nadd = len(I().product(*tree.fac))
        nmult = nadd * (len(tree.fac) - 1)
        nlook = nadd * len(tree.fac)
        # children's (nmult, nadd, nlook) triples plus this node's own
        triples = list(reschildren) + [(nmult,nadd,nlook)]
        return tuple(sum(column) for column in zip(*triples))

    return folddetfac(tree,branch=branch,leaf=leaf,shared=shared)

def cgen(tree):
    '''Generate a complete C program that evaluates the plan ``tree``.

    Every node becomes a C function ``fNNN`` (see ``fun``) that returns a
    heap-allocated array, stored in ``main`` as ``aNNN`` (see ``arr``).
    Node ids are assigned in post-order and ``main`` emits the children's
    calls before the parent's, so every array exists before it is used.
    The root array -- or, for a ``dict`` root, each entry's array -- is
    printed and freed at the end of ``main``; every other array is freed
    inside the last node function that consumes it.

    Supported leaves: Multinom, Fun (int arrays), I.  Supported branches:
    SumProd, Index, Reorder, Embed and a dict of sub-plans.

    Returns the C source as a string.
    Raises TypeError for any other node type.
    '''
    ids = itertools.count()
    # refs[nodeid]: number of parents that still have to consume the node's
    # array (countrefs numbers nodes in the same post-order as below).
    refs = countrefs(tree)

    def factor_strides(f):
        '''Return (strides, backstrides) dicts for the flat array of factor f.

        strides[d] is the index step when variable d is incremented (the
        first variable of f.domlist varies fastest); backstrides[d] is the
        step back when d wraps from its last value to 0.
        '''
        strides = {}
        backstrides = {}
        curstride = 1
        for d in f.domlist:
            card = len(f.domtypes[d])
            strides[d] = curstride
            backstrides[d] = curstride * (card - 1)
            curstride *= card
        return strides, backstrides

    def release(cids):
        '''Drop one reference to each child id and return the C free()
        statements for the children whose last reference this was.'''
        frees = []
        for cid in cids:
            refs[cid] -= 1
            if refs[cid] == 0:
                frees.append('free({arr});'.format(arr=arr(cid)))
        return frees

    def leaf(tree):
        # A leaf is a constant table: copy its parameters into a heap array.
        if isinstance(tree,Multinom):
            elt_type = 'float'
            pars=repr_elems(tree.par)
        elif isinstance(tree,Fun):
            elt_type = 'int'
            pars=repr_elems(tree.par)
        elif isinstance(tree,I):
            elt_type = 'float'
            pars='1.0'
        else:
            raise TypeError(
                'cgen: unsupported leaf type {0}'.format(type(tree).__name__))
        nodeid = next(ids)
        funcode = '''
            {elt_type} *{fun}()
            {{
                {elt_type} stackarr[{arrlen}] = {{ {pars} }};
                {elt_type} *heaparr = malloc(sizeof({elt_type})*{arrlen});
                memcpy(heaparr,stackarr,sizeof({elt_type})*{arrlen});
                return heaparr;
            }}
            '''.format(elt_type=elt_type,fun=fun(nodeid),arrlen=len(tree),pars=pars)
        maincode =\
            '{elt_type} *{arr} = {fun}();'.format(elt_type=elt_type,arr=arr(nodeid),fun=fun(nodeid))
        return trim(funcode,8),maincode,nodeid

    def branch(tree,reschildren):
        # Children were generated first (post-order); gather their function
        # definitions, main-code lines and node ids.
        nodeid = next(ids)
        chfuncode = '\n        '.join(code for code,_,_ in reschildren)
        chmaincode = '\n            '.join(code for _,code,_ in reschildren)
        cids = [cid for _,_,cid in reschildren]
        #----------------------------------------------------------------------------------------
        if isinstance(tree,SumProd):
            # Walk every joint assignment of the merged domain like an
            # odometer, adding the product of the factor entries into the
            # result cell; variables in tree.arg do not index the result.
            imdomlist,imdomtypes = mergedoms(tree.fac)
            imdsize = len(imdomlist)
            strides = {}
            backstrides = {}
            for cid,f in zip(cids,tree.fac):
                fstrides,fbackstrides = factor_strides(f)
                strides[cid] = [fstrides.get(imd,0) for imd in imdomlist]
                backstrides[cid] = [fbackstrides.get(imd,0) for imd in imdomlist]

            # resix moves by resstrides[i] when digit i is incremented after
            # all lower digits wrapped: +1 for a kept variable, and for a
            # summed variable 1 - (product of lower kept cardinalities),
            # which rewinds the kept lower digits back to their start.
            resstrides = []
            startingsumvar = True
            curstride = 1
            for imd in imdomlist:
                if imd in tree.arg:
                    if startingsumvar:
                        resstrides.append(0)
                    else:
                        resstrides.append(1-curstride)
                else:
                    startingsumvar = False
                    curstride *= len(imdomtypes[imd])
                    resstrides.append(1)

            formal_args = ', '.join(
                'float *{arr}'.format(arr=arr(cid))
                               for cid in cids)
            concrete_args = ', '.join(
                '{arr}'.format(arr=arr(cid))
                               for cid in cids)
            defstrides = '\n                    '.join(
                'int strides_{cid}[{imdsize}] = {{ {fslist} }};'.format(cid=cid,imdsize=imdsize,fslist=repr_elems(strides[cid]))
                                  for cid in cids)
            defbackstrides = '\n                    '.join(
                'int backstrides_{cid}[{imdsize}] = {{ {fbslist} }};'.format(cid=cid,imdsize=imdsize,fbslist=repr_elems(backstrides[cid]))
                                      for cid in cids)
            defresstrides =\
                'int resstrides[{imdsize}] = {{ {rslist} }};'.format(imdsize=imdsize,rslist=repr_elems(resstrides))
            defimdomcards =\
                'int imdomcards[{imdsize}] = {{ {sizelist} }};'.format(
                    imdsize=imdsize,sizelist=repr_elems(len(imdomtypes[d]) for d in imdomlist))
            declindices = '\n                    '.join(
                'int facix_{cid} = 0;'.format(cid=cid)
                                   for cid in cids)
            resultprod = ' * '.join('{arr}[facix_{cid}]'.format(arr=arr(cid),cid=cid) for cid in cids)
            takestrides = '\n                                '.join(
                'facix_{cid} += strides_{cid}[i];'.format(cid=cid)
                                   for cid in cids)
            takebackstrides = '\n                                '.join(
                'facix_{cid} -= backstrides_{cid}[i];'.format(cid=cid)
                                   for cid in cids)
            freechildren = '\n                    '.join(release(cids))
            arrlen = len(tree)
            arrid = arr(nodeid)
            funid = fun(nodeid)

            funcode = '''
                float *{funid}({formal_args})
                {{
                    {defstrides}
                    {defbackstrides}
                    {defresstrides}
                    {defimdomcards}
                    
                    {declindices}
                    int resix = 0;
                    float *result = calloc({arrlen}, sizeof(float));
                    int assignment[{imdsize}] = {{0}};
                    int i;
                    
                    while (1)
                    {{
                        outerloop:
                        result[resix] += {resultprod};
                        for (i=0; i<{imdsize}; i++)
                        {{
                            assignment[i] += 1;
                            if (assignment[i]==imdomcards[i])
                            {{
                                assignment[i]=0;
                                {takebackstrides}
                            }} else {{
                                {takestrides}
                                resix += resstrides[i];
                                goto outerloop;
                            }}
                        }}
                        break;
                    }}
                    
                    {freechildren}
                    
                    return result;
                }}
                '''.format(**locals())
            maincode =\
                'float *{arrid} = {funid}({concrete_args});'.format(**locals())

            return '\n        '.join([chfuncode,trim(funcode,8)]), '\n            '.join([chmaincode,maincode]), nodeid
        # -----------------------------------------------------------------------------------------------------------
        elif isinstance(tree,Index):
            # Children: the deterministic functions tree.det first, the
            # indexed factor last.  For each result assignment copy the
            # factor entry selected by the det functions' current values.
            detcids = cids[:-1]
            detdomlist,_ = mergedoms(tree.det)
            remdomlist = tree.domlist[:]
            for v in reversed(detdomlist):
                # pop outside the assert: under `python -O` an assert (and
                # any side effect inside it) is not executed at all
                last = remdomlist.pop()
                assert last == v, 'Index: det domain must be the tail of tree.domlist'
            remdomsize = len(remdomlist)
            nresdom = len(tree.domlist)
            facfstrides,facfbackstrides = factor_strides(tree.fac[0])
            facstrides = [facfstrides.get(resvar,0) for resvar in tree.domlist]
            facbackstrides = [facfbackstrides.get(resvar,0) for resvar in tree.domlist]
            # stride in the factor of each det function's codomain variable
            codstride = {}
            for cid,df in zip(detcids,tree.det):
                codstride[cid] = facfstrides.get(df.codlist[0])
            detstrides = {}
            detbackstrides = {}
            for cid,df in zip(detcids,tree.det):
                fstrides,fbackstrides = factor_strides(df)
                detstrides[cid] = [fstrides.get(resvar,0) for resvar in tree.domlist]
                detbackstrides[cid] = [fbackstrides.get(resvar,0) for resvar in tree.domlist]

            initcurfunval = '\n                    '.join(
                'int curfunval_{cid} = {arr}[0];'.format(cid=cid,arr=arr(cid)) for cid in detcids)
            initfacix = '\n                    '.join(
                'facix += curfunval_{cid} * codstride_{cid};'.format(cid=cid) for cid in detcids)

            formal_args = ', '.join(
                '{elt_type} *{arr}'.format(arr=arr(cid), elt_type='int' if 'codlist' in ch else 'float')
                               for cid,ch in zip(cids,getdetfac(tree)))
            concrete_args = ', '.join(
                '{arr}'.format(arr=arr(cid))
                               for cid in cids)
            deffacstrides =\
                'int facstrides[{nresdom}] = {{ {fslist} }};'.format(nresdom=nresdom,fslist=repr_elems(facstrides))
            deffacbackstrides =\
                'int facbackstrides[{nresdom}] = {{ {fbslist} }};'.format(nresdom=nresdom,fbslist=repr_elems(facbackstrides))
            defdetstrides = '\n                    '.join(
                'int detstrides_{cid}[{nresdom}] = {{ {dslist} }};'.format(cid=cid,nresdom=nresdom,dslist=repr_elems(detstrides[cid]))
                                  for cid in detcids)
            defdetbackstrides = '\n                    '.join(
                'int detbackstrides_{cid}[{nresdom}] = {{ {dbslist} }};'.format(cid=cid,nresdom=nresdom,dbslist=repr_elems(detbackstrides[cid]))
                                      for cid in detcids)
            defcodstrides = '\n                    '.join(
                'int codstride_{cid} = {cs};'.format(cid=cid,cs=codstride[cid])
                                  for cid in detcids)
            defdomcards =\
                'int domcards[{nresdom}] = {{ {sizelist} }};'.format(
                    nresdom=nresdom,sizelist=repr_elems(len(tree.domtypes[d]) for d in tree.domlist))
            decldetix = '\n                    '.join(
                'int detix_{cid} = 0;'.format(cid=cid)
                                   for cid in detcids)
            # On an increment, re-read the det function and shift facix by
            # the change of its value; on a wrap only detix is rewound and
            # curfunval is reconciled at the next increment.
            takedetstrides = '\n                                '.join(
                             '''detix_{cid} += detstrides_{cid}[i];
                                newfunval = {arr}[detix_{cid}];
                                facix += (newfunval - curfunval_{cid}) * codstride_{cid};
                                curfunval_{cid} = newfunval;'''.format(cid=cid,arr=arr(cid))
                                   for cid in detcids)
            takedetbackstrides = '\n                                '.join(
                'detix_{cid} -= detbackstrides_{cid}[i];'.format(cid=cid)
                                   for cid in detcids)
            freechildren = '\n                    '.join(release(cids))
            arrlen = len(tree)
            arrid = arr(nodeid)
            funid = fun(nodeid)
            facarr = arr(cids[-1])
            elt_type = 'int' if 'codlist' in tree else 'float'

            funcode = '''
                {elt_type} *{funid}({formal_args})
                {{
                    {deffacstrides}
                    {deffacbackstrides}
                    {defdetstrides}
                    {defdetbackstrides}
                    {defcodstrides}
                    {defdomcards}
                    
                    {initcurfunval}
                    int facix = 0;
                    {initfacix}
                    {decldetix}
                    {elt_type} *result = malloc(sizeof({elt_type}) * {arrlen});
                    int assignment[{nresdom}] = {{0}};
                    int i;
                    int resix;
                    int newfunval;
                    
                    for (resix=0; resix<{arrlen};)
                    {{
                        result[resix] = {facarr}[facix];
                        for (i=0; i<{remdomsize}; i++)
                        {{
                            assignment[i] += 1;
                            if (assignment[i]==domcards[i])
                            {{
                                assignment[i]=0;
                                facix -= facbackstrides[i];
                            }} else {{
                                facix += facstrides[i];
                                goto outerloop;
                            }}
                        }}
                        for (i={remdomsize}; i<{nresdom}; i++)
                        {{
                            assignment[i] += 1;
                            if (assignment[i]==domcards[i])
                            {{
                                assignment[i]=0;
                                facix -= facbackstrides[i];
                                {takedetbackstrides}
                            }} else {{
                                facix += facstrides[i];
                                {takedetstrides}
                                goto outerloop;
                            }}
                        }}
                        outerloop:
                        resix++;
                    }}
                    
                    {freechildren}
                    
                    return result;
                }}
                '''.format(**locals())
            maincode =\
                '{elt_type} *{arrid} = {funid}({concrete_args});'.format(**locals())

            return '\n        '.join([chfuncode,trim(funcode,8)]), '\n            '.join([chmaincode,maincode]), nodeid

        # -----------------------------------------------------------------------------------------------------------
        elif isinstance(tree,Reorder):
            # Permute the variables of the single child: tree.dic lists, in
            # result-domain order, which child variable each position maps to.
            nresdom = len(tree.domlist)
            faccid = cids[0]
            fstrides,fbackstrides = factor_strides(tree.fac[0])
            facstrides = [fstrides[old] for one in tree.dic for old in one.values()]
            facbackstrides = [fbackstrides[old] for one in tree.dic for old in one.values()]

            formal_args = '{elt_type} *{arr}'.format(arr=arr(faccid), elt_type='int' if 'codlist' in tree.fac[0] else 'float')
            concrete_args = '{arr}'.format(arr=arr(faccid))
            deffacstrides =\
                'int facstrides[{nresdom}] = {{ {fslist} }};'.format(nresdom=nresdom,fslist=repr_elems(facstrides))
            deffacbackstrides =\
                'int facbackstrides[{nresdom}] = {{ {fbslist} }};'.format(nresdom=nresdom,fbslist=repr_elems(facbackstrides))
            defdomcards =\
                'int domcards[{nresdom}] = {{ {sizelist} }};'.format(
                    nresdom=nresdom,sizelist=repr_elems(len(tree.domtypes[d]) for d in tree.domlist))

            freechildren = '\n                    '.join(release(cids))
            arrlen = len(tree)
            arrid = arr(nodeid)
            funid = fun(nodeid)
            facarr = arr(faccid)
            elt_type = 'int' if 'codlist' in tree else 'float'

            funcode = '''
                {elt_type} *{funid}({formal_args})
                {{
                    {deffacstrides}
                    {deffacbackstrides}
                    {defdomcards}
                    
                    int facix = 0;
                    {elt_type} *result = malloc(sizeof({elt_type}) * {arrlen});
                    int assignment[{nresdom}] = {{0}};
                    int i;
                    int resix;
                    int newfunval;
                    
                    for (resix=0; resix<{arrlen};)
                    {{
                        result[resix] = {facarr}[facix];
                        for (i=0; i<{nresdom}; i++)
                        {{
                            assignment[i] += 1;
                            if (assignment[i]==domcards[i])
                            {{
                                assignment[i]=0;
                                facix -= facbackstrides[i];
                            }} else {{
                                facix += facstrides[i];
                                goto outerloop;
                            }}
                        }}
                        outerloop:
                        resix++;
                    }}
                    
                    {freechildren}
                    
                    return result;
                }}
                '''.format(**locals())
            maincode =\
                '{elt_type} *{arrid} = {funid}({concrete_args});'.format(**locals())

            return '\n        '.join([chfuncode,trim(funcode,8)]), '\n            '.join([chmaincode,maincode]), nodeid
        # -----------------------------------------------------------------------------------------------------------
        elif isinstance(tree,Embed):
            # Indicator table: the first len(tree.det) result variables are the
            # det functions' codomains; for each assignment of the remaining
            # variables write 1.0 at the cell holding the functions' values.
            # All other cells stay 0 (calloc).
            nresdom = len(tree.domlist)
            ndetdom = len(tree.det)
            detstrides = {}
            detbackstrides = {}
            for cid,df in zip(cids,tree.det):
                fstrides,fbackstrides = factor_strides(df)
                detstrides[cid] = [fstrides.get(resvar,0) for resvar in tree.domlist]
                detbackstrides[cid] = [fbackstrides.get(resvar,0) for resvar in tree.domlist]
            codstride={}
            curstride=1
            for cid,d in zip(cids,tree.domlist):
                codstride[cid] = curstride
                curstride *= len(tree.domtypes[d])
            # result-index step for one increment of the non-codomain variables
            resstride = curstride

            initcurfunval = '\n                    '.join(
                'int curfunval_{cid} = {arr}[0];'.format(cid=cid,arr=arr(cid)) for cid in cids)
            initresix = '\n                    '.join(
                'resix += curfunval_{cid} * codstride_{cid};'.format(cid=cid) for cid in cids)

            formal_args = ', '.join(
                'int *{arr}'.format(arr=arr(cid))
                               for cid in cids)
            concrete_args = ', '.join(
                '{arr}'.format(arr=arr(cid))
                               for cid in cids)

            defdetstrides = '\n                    '.join(
                'int detstrides_{cid}[{nresdom}] = {{ {dslist} }};'.format(cid=cid,nresdom=nresdom,dslist=repr_elems(detstrides[cid]))
                                  for cid in cids)
            defdetbackstrides = '\n                    '.join(
                'int detbackstrides_{cid}[{nresdom}] = {{ {dbslist} }};'.format(cid=cid,nresdom=nresdom,dbslist=repr_elems(detbackstrides[cid]))
                                      for cid in cids)
            defcodstrides = '\n                    '.join(
                'int codstride_{cid} = {cs};'.format(cid=cid,cs=codstride[cid])
                                  for cid in cids)
            defdomcards =\
                'int domcards[{nresdom}] = {{ {sizelist} }};'.format(
                    nresdom=nresdom,sizelist=repr_elems(len(tree.domtypes[d]) for d in tree.domlist))
            decldetix = '\n                    '.join(
                'int detix_{cid} = 0;'.format(cid=cid)
                                   for cid in cids)
            takedetstrides = '\n                                '.join(
                             '''detix_{cid} += detstrides_{cid}[i];
                                newfunval = {arr}[detix_{cid}];
                                resix += (newfunval - curfunval_{cid}) * codstride_{cid};
                                curfunval_{cid} = newfunval;'''.format(cid=cid,arr=arr(cid))
                                   for cid in cids)
            takedetbackstrides = '\n                                '.join(
                'detix_{cid} -= detbackstrides_{cid}[i];'.format(cid=cid)
                                   for cid in cids)
            freechildren = '\n                    '.join(release(cids))
            arrlen = len(tree)
            arrid = arr(nodeid)
            funid = fun(nodeid)

            funcode = '''
                float *{funid}({formal_args})
                {{
                    {defdetstrides}
                    {defdetbackstrides}
                    {defcodstrides}
                    {defdomcards}
                    
                    {initcurfunval}
                    int resix = 0;
                    {initresix}
                    {decldetix}
                    float *result = calloc({arrlen}, sizeof(float));
                    int assignment[{nresdom}] = {{0}};
                    int i;
                    int newfunval;
                    int resstride = {resstride};
                    
                    while (1)
                    {{
                        outerloop:
                        result[resix] = 1.0;
                        resix += resstride;
                        for (i={ndetdom}; i<{nresdom}; i++)
                        {{
                            assignment[i] += 1;
                            if (assignment[i]==domcards[i])
                            {{
                                assignment[i]=0;
                                {takedetbackstrides}
                            }} else {{
                                {takedetstrides}
                                goto outerloop;
                            }}
                        }}
                        goto breakloop;
                    }}
                    
                    breakloop:
                    {freechildren}
                    
                    return result;
                }}
                '''.format(**locals())
            maincode =\
                'float *{arrid} = {funid}({concrete_args});'.format(**locals())

            return '\n        '.join([chfuncode,trim(funcode,8)]), '\n            '.join([chmaincode,maincode]), nodeid

        elif isinstance(tree,dict):
            # A dict of sub-plans: print every entry.  printarr frees each
            # entry itself, so no release()/free of the children here.
            maincode = '\n            '.join([
                printarr(arr(cid),len(ch),'int' if 'codlist' in ch else 'float')
                for cid,ch in zip(cids,getdetfac(tree))])
            return chfuncode, '\n                '.join([chmaincode,maincode]), nodeid

        else:
            raise TypeError(
                'cgen: unsupported branch type {0}'.format(type(tree).__name__))

    def shared(res):
        # A node already generated: reuse its id, emit no code again.
        funcode,maincode,nodeid = res
        return '','',nodeid

    def printarr(arr,arrlen,elt_type='float'):
        '''Return C code that prints ``arr`` (length ``arrlen``) as a bracketed
        list, using the ``i`` declared in main, and then frees it.'''
        # int arrays must not go through %f: printf would read a double.
        fmt = '%d' if elt_type == 'int' else '%f'
        return(trim('''
            printf("[");
            if ({arrlen}>0)
            {{
                printf("{fmt}",{arr}[0]);
            }}
            for (i=1; i<{arrlen}; i++)
            {{
                printf(", {fmt}",{arr}[i]);
            }}
            printf("]\\n");
            free({arr});
            '''.format(arr=arr,arrlen=arrlen,fmt=fmt),12))

    funcode,maincode,rootid = folddetfac(tree,branch=branch,leaf=leaf,shared=shared)
    if isinstance(tree,Factor):
        printcode=printarr(arr(rootid),len(tree),'int' if 'codlist' in tree else 'float')
    else:
        printcode=''
    # `int main(void)`: implicit-int `main()` is invalid since C99.
    code = '''
        #include <stdio.h>
        #include <stdlib.h>
        #include <string.h>
        
        {funcode}
        
        int main(void)
        {{
            int i;
            {maincode}
            {printcode}
            return(0);
        }}
        '''.format(funcode=funcode,maincode=maincode,printcode=printcode)
    return trim(code)

