from .utils import product
from .factor import *

# Public API of this module.
__all__ = [
    've_order',
    've_minweight',
    'junctiontree',
    'marginals',
    'indextree',
    'cost',
]

def ve_order(facs, order):
    """Variable elimination with a given elimination order.

    For each variable in ``order`` (in sequence), the factors that mention it
    are wrapped in a SumProd that sums that variable out, and this new
    expression replaces them in the working factor list.

    facs: list (or other iterable) of factors, or dict whose values are the
          factors
    order: list of variables, eliminated in this order
    Returns a factor expression equal to the product of facs, with the
    variables in ``order`` summed out.
    """
    try:
        facs = facs.values()  # if we get a dict as input instead of a list
    except AttributeError:
        pass
    # Materialize once: otherwise a dict view would be handed straight to
    # SumProd when ``order`` is empty, and a one-shot iterator would be
    # exhausted while collecting the factors of the first variable, silently
    # dropping every other factor.
    facs = list(facs)
    for rv in order:
        rv_facs = []
        other_facs = []
        # Single pass: split factors into those mentioning rv and the rest,
        # preserving their relative order.
        for fac in facs:
            if rv in fac.domtypes:
                rv_facs.append(fac)
            else:
                other_facs.append(fac)
        rv_sum = SumProd([rv], rv_facs)
        facs = [rv_sum] + other_facs
    return SumProd([], facs)

def ve_minweight(facs, query):
    """Variable elimination with minweight heuristic.

    At each step the non-query variable whose elimination produces the
    smallest intermediate product (measured by ``len`` of the SumProd of all
    factors mentioning it) is summed out next.  Ties go to the variable that
    appears first in ``facs``, so the result is deterministic.

    facs: list (or other iterable) of factors, or dict whose values are the
          factors
    query: list of variables to keep
    Returns a factor expression equal to the product of facs, with all
    variables summed out except those in query.
    """
    try:
        facs = facs.values()  # if we get a dict as input instead of a list
    except AttributeError:
        pass
    # Materialize once: the variable scan below would otherwise exhaust a
    # one-shot iterator, and SumProd should receive a list, not a dict view.
    facs = list(facs)

    # Variables to eliminate, in order of first appearance.  A list rather
    # than a set keeps tie-breaking between equal-weight candidates
    # reproducible (set iteration order of e.g. strings varies per process
    # under hash randomization).
    seen = set(query)
    remaining = []
    for fac in facs:
        for var in fac.domlist:
            if var not in seen:
                seen.add(var)
                remaining.append(var)

    while remaining:
        best = None  # (rv, rv_prod, rv_weight) of the lightest candidate
        for rv in remaining:
            rv_prod = SumProd([], [fac for fac in facs if rv in fac.domtypes])
            try:
                rv_weight = len(rv_prod)
            except OverflowError:
                # len() cannot report sizes above sys.maxsize (e.g. on
                # 32-bit builds); such a candidate is effectively infinite.
                rv_weight = float('inf')
            # The explicit None check guarantees a pick even if every
            # candidate overflowed.
            if best is None or rv_weight < best[2]:
                best = rv, rv_prod, rv_weight
        rv, rv_prod, _ = best
        remaining.remove(rv)
        rv_sum = SumProd([rv], rv_prod.fac)
        facs = [rv_sum] + [fac for fac in facs if rv not in fac.domtypes]
    return SumProd([], facs)


def junctiontree(tree):
    """Junction tree propagation based on a nested SumProd expression.

    Assumes ``tree`` is in SumProd format and satisfies the running
    intersection property (RIP).

    Returns a dict {'fac': L}. Here, L is a list of factor expressions, one
    for each leaf in the original expression. Subexpressions of these are
    shared such that the evaluation cost is only about twice that of the
    original expression.

    Raises TypeError if a node is neither a SumProd nor a Multinom,
    SimpleIndex or Index leaf.
    """
    def downwards(tree, upfac):
        # upfac is the message arriving from above; assumes
        # upfac.domlist is a subset of tree.domlist.
        if isinstance(tree, SumProd):
            children = list(tree.fac)
            out = []
            for i, f in enumerate(children):
                # Message to child i: incoming message times every *other*
                # child, summed down to the child's domain.  Exclude by
                # position rather than identity so that a factor object that
                # occurs more than once in tree.fac is still counted for its
                # remaining occurrences.
                others = children[:i] + children[i + 1:]
                prod_others = SumProd([], [upfac] + others)
                out.extend(downwards(f, prod_others.sumto(f.domlist)))
            return out
        elif isinstance(tree, (Multinom, SimpleIndex, Index)):
            # Leaf: its result is the leaf itself times its incoming message.
            return [SumProd([], [tree, upfac])]
        else:
            raise TypeError('unexpected factor type: ' + str(type(tree)))
    # Start from a constant unit message at the root.
    return setfac({}, downwards(tree, Multinom([], [1.0])))  # TODO: use I()

def marginals(multifac):
    """Single-variable marginals from a multi-factor expression.

    multifac: a dict with key 'fac' holding a list of factor expressions,
              like the one returned from the function junctiontree(tree).
    For every variable, the first factor in multifac['fac'] whose domlist
    contains it is summed down to that single variable.
    Returns setfac({}, ...) applied to those marginals, in the order the
    variables were first encountered.
    """
    by_var = {}
    pairs = ((fac, var) for fac in multifac['fac'] for var in fac.domlist)
    for fac, var in pairs:
        if var in by_var:
            continue  # already covered by an earlier factor
        by_var[var] = fac.sumto([var])
    return setfac({}, by_var.values())

def indextree(det, tree):
    """Insert evidence into a factor expression.

    Every Multinom leaf is indexed by the observed values of the evidence
    variables it contains (evidence ends up just above each such leaf), and
    the evidence variables are removed from the summed-out variables of
    every SumProd node.

    det: dict of type {'varname': 'observed val'}
    Returns the rewritten expression produced by folddetfac.
    """
    # TODO test if it works with multiple vars
    observed = set(det)

    def index_leaf(node):
        assert isinstance(node, Multinom)
        hits = observed & set(node.domtypes.keys())
        return node.index({var: det[var] for var in hits})

    def rebuild_branch(node, children):
        assert isinstance(node, SumProd)
        hits = observed & set(node.arg)
        kept = [var for var in node.arg if var not in hits]
        # Collapse a node that neither sums anything nor combines factors.
        if not kept and len(children) == 1:
            return children[0]
        # Otherwise always build a fresh SumProd, even if nothing was
        # removed: per the original author, reusing the old node leaves its
        # domlist wrong once the children have been rewritten.
        return SumProd(kept, children)

    return folddetfac(tree, rebuild_branch, index_leaf)

def cost(tree):
    '''Count additions, multiplications and lookups.

    Returns a tuple (nadd, nmult, nlook) summed over all SumProd nodes of
    ``tree``; ``tree`` may also be a junctiontree() result dict. Leaves and
    nodes handled by folddetfac's ``shared`` callback (presumably
    already-visited shared subexpressions) contribute nothing.

    Raises TypeError for a branch node that is neither a SumProd nor a dict.
    '''
    zero = (0, 0, 0)

    def add(triples):
        # Column-wise sum of cost triples; seeding with zero keeps the result
        # a 3-tuple even when there are no children.
        return tuple(map(sum, zip(zero, *triples)))

    def leaf(tree):
        return zero

    def shared(tree):
        return zero

    def branch(tree, reschildren):
        if isinstance(tree, SumProd):
            # nr of additions is equal to the size of the intermediate domain
            nadd = len(SumProd([], tree.fac))
            # assume smart implementation that only multiplies if #fac > 1
            # NOTE(review): a product of nfac terms needs nfac-1 multiplies;
            # this counts nfac per entry -- confirm that is intended.
            nfac = len(tree.fac)
            nmult = nadd * (nfac if nfac > 1 else 0)
            # lookups are always done
            nlook = nadd * nfac
            return add([(nadd, nmult, nlook)] + list(reschildren))
        elif isinstance(tree, dict):  # for junctiontree result
            return add(reschildren)
        else:
            # An explicit raise (not `assert False`) survives `python -O`,
            # where the assert would vanish and None would leak upwards.
            raise TypeError('unexpected node type: ' + str(type(tree)))

    return folddetfac(tree, leaf=leaf, branch=branch, shared=shared)


