"""Python-based numeric evaluator for symfer.

Uses NumPy (https://numpy.org) for the array manipulations, and a layer
above it (FArray) that associates variable names with array dimensions.
"""
# Some of the name-bookkeeping functionality is currently duplicated in the
# modules .factor and .farray; this is because the latter is a remnant of a
# previous version (pinfra).
# NB the FArray layer does not necessarily follow the domain order of .factor

import numpy as np
from . import farray
from .utils import nest
from .factor import *

__all__ = """evaluate""".split()

def _indicator_array(idx, sizes):
    """Return a 0/1 float array marking, per cell of ``idx``, the position it holds.

    The result has shape ``idx.shape + tuple(sizes)`` and is zero except that,
    for every index tuple ``i`` of ``idx``, ``z[i + (idx[i],)]`` is set to 1.0.
    Only the first trailing axis is indexed by the values of ``idx``; any
    further trailing axes are filled entirely along that selection.

    :param idx: ndarray (possibly 0-d) of positions along the first trailing
        axis; assumed to hold integers (NumPy rejects float indices).
    :param sizes: lengths of the trailing axes appended after ``idx.shape``.
    """
    z = np.zeros(idx.shape + tuple(sizes))
    # Open mesh over idx's own axes: the k-th array is arange(idx.shape[k])
    # shaped to broadcast along axis k only, so together with idx it picks
    # exactly one element per cell of idx.
    grid = np.ix_(*[np.arange(s) for s in idx.shape])
    # Must be a *tuple*: NumPy >= 1.23 no longer interprets a list of index
    # arrays as a multidimensional index.
    z[tuple(grid) + (idx,)] = 1.0
    return z


def evaluate(tree):
    """Numerically evaluate a Factor expression.

    The tree is folded bottom-up with ``folddetfac(tree, branch, leaf)``:
    leaves become ``farray.FArray`` objects, and inner nodes combine the
    already-evaluated results of their children.  FArray dimensions are kept
    in reversed domain order (``domlist[::-1]``).

    :param tree: a Factor expression (from ``.factor``).  A ``dict`` node is
        treated as the top level of a junction tree.
    :returns: a ``Multinom`` when the fold yields a single FArray.  Otherwise
        (the dict / junction-tree case) it returns ``setfac({}, res)``, where
        ``res`` holds one ``Multinom`` per child in ``getfac(tree)``.
    :raises TypeError: if a node of an unsupported factor type is encountered.
    """
    def branch(tree, reschildren):
        # Combine the evaluated children ``reschildren`` of inner node ``tree``.
        if isinstance(tree, dict):  # top level for junction tree
            return reschildren
        elif isinstance(tree, SumProd):
            # Multiply all children together, then sum out ``tree.arg``.
            result = farray.FArray(np.array(1.0), [])
            for a in reschildren:
                result = result * a    # TODO: check why *= operator in FArray fails
            return result.sum(tree.arg)
        elif isinstance(tree, SimpleIndex):
            # Fix each determined variable at the position of its value within
            # that variable's domain type.
            fac = tree.fac[0]
            facres = reschildren[0]
            return facres(**dict((d, fac.domtypes[d].index(v))
                                 for d, v in tree.det.items()))
        elif isinstance(tree, Index):
            # The last child is the indexed factor.  The preceding children
            # are the evaluated index expressions, one per entry of tree.det,
            # each bound to that entry's first codomain variable.
            detres = reschildren[:-1]
            facres = reschildren[-1]
            indexed = facres(**dict((d.codlist[0], dr)
                                    for d, dr in zip(tree.det, detres)))
            revdomlist = tree.domlist[::-1]
            return farray.FArray(indexed.reorder(revdomlist), revdomlist)
        elif isinstance(tree, Embed):
            # Turn each child's array of positions into a 0/1 indicator over
            # the codomain of the matching ``tree.det`` entry, and multiply
            # the indicators together.
            result = farray.FArray(np.array(1.0), [])
            for a, d in zip(reschildren, tree.det):
                sizes = [len(d.codtypes[c]) for c in d.codlist]
                z = _indicator_array(a.arr, sizes)
                result = result * farray.FArray(z, a.dims + d.codlist)
            return result
        elif isinstance(tree, Reorder):
            # tree.dic is a sequence of {new: old} mappings; invert them for
            # renaming, then lay the result out in reversed domain order.
            old2new = dict((v, k) for entries in tree.dic
                           for (k, v) in entries.items())
            revdomlist = tree.domlist[::-1]
            return farray.FArray(reschildren[0].rename(old2new).reorder(revdomlist),
                                 revdomlist)
        else:
            raise TypeError('unexpected branch factor type: '+str(type(tree)))

    def leaf(tree):
        # Evaluate a leaf factor into an FArray.
        if isinstance(tree, (Multinom, Fun)):
            # The flat parameter list is nested into the shape given by the
            # domain sizes in reversed domain order.
            revdoms = tree.domlist[::-1]
            return farray.FArray(
                np.array(nest(tree.par, [len(tree.domtypes[k]) for k in revdoms])),
                revdoms)
        elif isinstance(tree, I):
            # Identity factor: a scalar 1 with no dimensions.
            return farray.FArray(np.array(1.0), [])
        else:
            raise TypeError('unexpected leaf factor type: '+str(type(tree)))

    res_farray = folddetfac(tree, branch, leaf)
    if isinstance(res_farray, farray.FArray):
        return farray_to_multinom(tree, res_farray)
    # Junction-tree case: the dict branch returned one result per child.
    return setfac({}, [farray_to_multinom(child, farr)
                       for child, farr in zip(getfac(tree), res_farray)])

def farray_to_multinom(tree, farr):
    """Convert an evaluated FArray back into a ``Multinom`` over ``tree``'s domain.

    :param tree: the factor whose ``domlist``/``domtypes`` define the domain.
    :param farr: the FArray holding the values for ``tree``.
    :returns: ``Multinom`` whose flat parameter list is ``farr`` laid out in
        reversed domain order (``tree.domlist[::-1]``) and flattened.
    """
    domain = makedom(tree.domlist, tree.domtypes)
    ordered = farr.reorder(tree.domlist[::-1])
    params = list(ordered.flat)
    return Multinom(domain, params)
