"""Minimal implementations of the inference algorithms, kept as simple as possible for the article."""

from symfer import SumProd, Multinom, I, indextree
from symfer.utils import product

# NB facs is list of factors
def ve_minweight(facs,query):
    """Variable elimination using a greedy minimum-weight ordering.

    Every variable that appears in some factor's ``domlist`` but is not in
    ``query`` is summed out, one at a time.  At each step the candidate
    whose combined factor (the product of all factors whose ``domtypes``
    mention it) has the smallest ``len()`` is eliminated; on ties the first
    candidate produced by set iteration wins.

    facs  -- list of factors
    query -- iterable of variables that must not be eliminated

    Returns the product of the factors left once elimination is done.
    """
    to_eliminate = set()
    for factor in facs:
        to_eliminate.update(factor.domlist)
    to_eliminate.difference_update(query)

    while to_eliminate:
        # Score each candidate by the size of the product it would create.
        # min() keeps the first minimal element, matching a strict '<' scan.
        scored = ((var, I().product(*[f for f in facs if var in f.domtypes]))
                  for var in to_eliminate)
        best_var, best_prod = min(scored, key=lambda pair: len(pair[1]))

        to_eliminate.remove(best_var)
        summed = best_prod.sumout([best_var])
        facs = [summed] + [f for f in facs if best_var not in f.domtypes]

    return I().product(*facs)

def ve_order(facs,order):
    """Variable elimination following a caller-supplied ordering.

    For each variable in ``order``, the factors whose ``domtypes`` mention
    it are wrapped in a ``SumProd`` node over that variable, and the node
    replaces them in the working list.  (The SumProd nodes are only built
    here; presumably they are evaluated lazily by the final product --
    confirm against symfer.)

    Returns ``I().product`` of the remaining factors.
    """
    pool = facs
    for var in order:
        node = SumProd([var], [f for f in pool if var in f.domtypes])
        pool = [node] + [f for f in pool if var not in f.domtypes]
    return I().product(*pool)

def ve_order_obs(facs,order,obs):
    """Condition factors on observations, then eliminate in a fixed order.

    facs  -- list of factors
    order -- sequence of variables to sum out, in that order
    obs   -- mapping from observed variable to its observed value

    Each factor is indexed (via ``f.index(**evidence)``) with only those
    observations whose variable appears in that factor's ``domlist``.  The
    resulting factors are then eliminated exactly as ``ve_order`` does.

    Returns the product of the factors left after elimination.
    """
    obsfacs = []
    for f in facs:
        # Pass only the evidence this factor actually has a variable for.
        # .items() works on both Python 2 and 3 dicts; .iteritems() is
        # Python-2-only.
        evidence = dict((var, val) for var, val in obs.items() if var in f.domlist)
        obsfacs.append(f.index(**evidence))
    # The elimination step is identical to ve_order, so reuse it.
    return ve_order(obsfacs, order)

def naive(facs,query,obs):
    """Brute-force inference: multiply everything, condition, then marginalize.

    Builds the full product of ``facs``, indexes it with ``obs`` and sums
    it down to the variables in ``query``.
    """
    # NOTE(review): ve_order_obs passes observations to .index() as keyword
    # arguments, whereas here the dict is passed positionally -- confirm
    # that symfer's index() accepts both forms.
    joint = I().product(*facs)
    conditioned = joint.index(obs)
    return conditioned.sumto(query)

# NB returns list of factors
def junctiontree(tree,upfac=None):
    """Propagate messages down a SumProd tree; return one factor per leaf.

    tree  -- a ``SumProd`` node or a ``Multinom`` leaf
    upfac -- the message arriving from the parent.  When omitted, a fresh
             ``I()`` is created for this call, rather than one instance
             built at import time and shared by every call.

    Returns a list containing, for each ``Multinom`` leaf, the leaf
    multiplied by the message it receives.

    Raises TypeError if ``tree`` (or any node below it) is neither a
    ``SumProd`` nor a ``Multinom``.
    """
    if upfac is None:
        upfac = I()
    if isinstance(tree,SumProd):
        out = []
        for f in tree.fac:
            # Message to child f: the parent's message times all of f's
            # siblings, reduced to f's own variables.
            prod_others = upfac.product(*[other for other in tree.fac if other is not f])
            out.extend(junctiontree(f,prod_others.sumto(f.domlist)))
        return out
    elif isinstance(tree,Multinom):
        return [upfac.product(tree)]
    # Previously this fell through and returned None, which surfaced later
    # as an obscure failure in out.extend(None).
    raise TypeError("junctiontree: expected SumProd or Multinom, got %r" % (tree,))

def marginals(jtlist):
    """Collect a single-variable marginal for every variable in ``jtlist``.

    jtlist -- list of factors, e.g. as returned by ``junctiontree``

    Returns a dict mapping each variable found in any factor's ``domlist``
    to ``factor.sumto([var])``, computed from the first factor (in list
    order) that contains that variable.
    """
    by_var = {}
    for clique in jtlist:
        for v in clique.domlist:
            if v in by_var:
                continue  # already marginalized from an earlier factor
            by_var[v] = clique.sumto([v])
    return by_var

def condition(tree):
    """Split a SumProd tree on the first variable it sums out.

    Takes ``tree.arg[0]`` as the pivot variable.  Its domain is read from
    ``domtypes`` of the product of the tree's factors.  Returns a list with
    one ``indextree({pivot: value}, tree)`` per value in that domain.
    """
    assert isinstance(tree,SumProd)
    pivot = tree.arg[0]
    domain = I().product(*tree.fac).domtypes[pivot]
    return [indextree({pivot: value}, tree) for value in domain]
    
