"""An instance of class Factor (or a subclass) is a data structure that
represents an expression in factor algebra.

Manipulation of these expressions is symfer's core functionality. Afterwards,
a factor expression can be exported as a YAML file using function
dumpyaml(expr1,filename); likewise, symfer can also load factor expressions
from YAML files using expr2 = loadyaml(filename). Roundtripping should result
in the same factor algebra expression (expr1==expr2 and sharing is preserved).

The representation of a Factor is kept as close as possible to the YAML data
structure; for example, the mapping
   !sumprod {arg: [A,B], fac: []}
corresponds to the factor
   SumProd(arg=['A','B'], fac=[])
which is an object with attributes 'arg' and 'fac'. The only difference is
(currently) that each factor has auxiliary attributes 'domlist' and 'domtypes'
providing information about its domain variables.

To provide maximum flexibility in the data structure layout, a factor
expression is not limited to Factor instances but may also contain plain
dicts. For example, the Python expression
   SumProd(arg=[], fac=[ {'label': 'aap', 'fac': [Multinom([],[])]} ])
and its YAML counterpart
   !sumprod
   arg: []
   fac:
   - label: aap
     fac: [!multinom {dom: [], par: []}]
are valid factor expressions in Python and YAML format, respectively.
If such a dict contains another factor expression, it should be in a singleton
list behind the key 'fac'. Then, in constructing the Python object from YAML,
the domlist/domtypes information of that factor is copied to the dict.
Otherwise, the dict gets domlist [] and domtypes {}.
(When domlist/domtypes are already provided in the dict, these are used;
however, they are currently stripped by dumpyaml, so this does not roundtrip.)

In order to give both Factor and dict factors a compatible interface, the
attributes of a factor can also be accessed as dict items:
   mysumprod['arg']

Factor expressions are meant to be IMMUTABLE; especially since they are meant
to be freely shared, mutation is a very bad idea and can break things at
unexpected places. It should only happen in a factor's own __init__ method.
"""

import yaml
import copy
import random
from utils import foldsh, identity, normalize, product

# Public API of this module (names exported by "from ... import *").
__all__ = [
    'Factor', 'SumProd', 'Index', 'SimpleIndex', 'Multinom', 'Fun', 'Embed',
    'Reorder', 'I',
    'getfac', 'setfac', 'getdetfac', 'setdetfac', 'folddetfac', 'reconstruct',
    'remove_doms', 'mergedoms', 'makedom',
]

# - deserializing from YAML has the advantage over unpickling that we can change
# the class definitions
# - also possible to use yaml.add_constructor, yaml.add_representer

class Factor(yaml.YAMLObject):
    """Base class of all factor-algebra expression nodes.

    Attributes can be accessed both as Python attributes (f.arg) and as
    dict items (f['arg']), so that Factor instances and plain dict factors
    can be handled through the same interface. Instances are meant to be
    treated as immutable once their __init__ has finished.
    """
    yaml_loader = yaml.SafeLoader

    def __getitem__(self,key): # or use __getattribute__ ?; also: raise appropriate errors
        return self.__dict__[key]
    def __setitem__(self,key,value):
        self.__dict__[key] = value
    def __delitem__(self,key):
        del self.__dict__[key]
    def __contains__(self,item):
        return item in self.__dict__
    def get(self,key,default=None):
        """Return self[key] if present, else default (like dict.get)."""
        return self.__dict__.get(key,default)
    def __eq__(self,other):
        # Structural equality: same class and equal attribute dicts.
        return (type(self) == type(other) and self.__dict__ == other.__dict__)
    def __ne__(self,other):
        return not self.__eq__(other)
    def __len__(self):
        """Return the product of the sizes of all domain types, i.e. the
        number of joint configurations of the domain variables."""
        return product(len(domtype) for domtype in self.domtypes.values())

    def sumto(self,remvars):
        """Convenience method for constructing a SumProd given the *remaining*
        variables instead of the variables being summed out.

        Every variable of self.domlist that is not in remvars is summed out.
        The summed variables are taken in self.domlist order, so the result
        is deterministic (a set difference would not be). See sumout for the
        shape of the result.
        """
        keep = set(remvars)
        return self.sumout([v for v in self.domlist if v not in keep])
    def sumout(self,sumvars):
        """Convenience method for constructing a SumProd.

        If this factor is already a SumProd, return a new SumProd with
        sumvars appended to self.arg. Otherwise, construct a SumProd with
        this factor as its only factor. If no variables need to be summed
        over, the factor itself is returned.

        sumvars may be any iterable of variable names.
        """
        sumvars = list(sumvars)  # also accept tuples/generators; never alias the caller's list
        if not sumvars:
            return self
        if isinstance(self,SumProd):
            return SumProd(self.arg+sumvars,self.fac)
        return SumProd(sumvars,[self])
    def product(self,*fac):
        """Convenience method for constructing a product.

        Unit factors I() are dropped. Returns I() if nothing remains, the
        single remaining factor if there is one, and SumProd([], factors)
        otherwise.
        """
        unit = I()
        realfac = [f for f in (self,)+fac if f != unit]
        if not realfac:
            return unit
        if len(realfac)==1:
            return realfac[0]
        return SumProd([],realfac)
    def reorder(self,domlist): #TODO omit Reorder when no changes
        """Convenience method returning Reorder(..,[self]).

        The new domain consists of the variables in domlist, in that order.
        An element of domlist can be one of the following:

        - var (a simple variable name)
        - a 1-element dict {newvar:oldvar}
        """
        # Plain vars map to themselves; smartdom then yields [{newvar: oldvar}, ...].
        mapping = dict((d,d) for d in domlist if not isinstance(d,dict))
        dic,_ = smartdom(domlist,mapping)
        return Reorder(dic,[self])
    def index(self,*det,**dic):
        """Convenience method for indexing by det.factors and var=val list:

        - f.index(Fun.random(...))
        - f.index(Rain=False)
        - f.index({'Rain':False})

        Each var=val pair becomes a constant Fun selecting val's position in
        self.domtypes[var]. Returns self when there is nothing to index by.
        """
        if len(det)==1 and isinstance(det[0],dict):
            dic.update(det[0])
            det = ()
        funs = [Fun(cod=[{var:self.domtypes[var]}],dom=[],par=[self.domtypes[var].index(val)])
                for var,val in dic.items()]
        newdet = list(det) + funs
        if newdet:
            return Index(newdet,[self])
        return self
    def embed(self,*det):
        """Convenience method returning Embed([self] + list(det))."""
        # det is a tuple; concatenating it directly to a list raises TypeError.
        return Embed([self]+list(det))
         
class SumProd(Factor):
    """Product of the factors in fac, with the variables in arg summed out.

    The domain is the merged domain of fac minus the variables in arg.
    """
    yaml_tag = u'!sumprod'

    def __init__(self,arg,fac): # NB not used in YAML parsing
        self.arg = arg
        self.fac = fac
        merged_list, merged_types = mergedoms(fac)
        # Drop each summed-out variable from the merged domain.
        for summed in arg:
            merged_list.remove(summed)
            del merged_types[summed]
        self.domlist = merged_list
        self.domtypes = merged_types

    def __repr__(self):
        return 'SumProd(arg={arg},fac={fac})'.format(arg=self.arg, fac=self.fac)

class Multinom(Factor):
    """Factor given by an explicit parameter list par over the ordered domain dom.

    dom is an ordered list of 1-element {var: domtype} dicts; the auxiliary
    domlist/domtypes attributes are derived from it.
    """
    yaml_tag = u'!multinom'

    def __init__(self,dom,par): # NB not used in YAML parsing
        self.dom, self.par = dom, par
        self.domlist, self.domtypes = split_ordered_list(dom)

    def __repr__(self):
        return 'Multinom(dom={dom},par={par})'.format(dom=self.dom, par=self.par)

    @staticmethod
    def random(domlist, **domtypes):
        """Return a Multinom with given domain and random parameters.

        The domain consists of the variables in domlist, in that order.
        An element of domlist can be one of the following:
        - var (a simple variable name)
        - a 1-element dict {var:domtype}

        If domtypes has a var:domtype entry for the same var, domtype
        will override the one in the list. If the list contains only
        var, such an entry is mandatory.
        If domtype x is an int, it will be converted to range(x).
        The parameters are normalized (sum to 1) over the first
        variable.
        """
        dom, sizes = smartdom(domlist,domtypes)
        # One uniform draw per joint configuration of the domain.
        draws = (random.random() for _ in xrange(product(sizes)))
        return Multinom(dom, normalize(draws, sizes[0]))

class SimpleIndex(Factor): # to be obsoleted
    """Index fac[0] by fixing variables to values.

    det is a {var: val} dict; the fixed variables are removed from the
    domain of fac[0].
    """
    yaml_tag = u'!sindex'

    def __init__(self,det,fac): # NB not used in YAML parsing
        self.det = det   # {'var':'val'} dict
        self.fac = fac
        base = fac[0]
        # Private copies, because the loop below mutates them.
        remaining = list(base['domlist'])
        types = dict(base['domtypes'])
        for fixed in det:
            remaining.remove(fixed)
            del types[fixed]
        self.domlist = remaining
        self.domtypes = types

    def __repr__(self):
        return 'SimpleIndex(det={det},fac={fac})'.format(det=self.det, fac=self.fac)

#TODO: more error checking
class Index(Factor):
    """Index the single factor fac[0] by the Fun factors in det.

    For each Fun d in det, the variable d.codlist[0] is removed from the
    domain of fac[0], and d's domain variables are appended to the domain in
    order of first occurrence, taking their domtypes from d. If fac[0] has
    codlist/codtypes, these are carried over.
    """
    yaml_tag = u'!index'
    def __init__(self,det,fac): # NB not used in YAML parsing
        self.det = det   # list of Fun
        self.fac = fac   # single element list
        assert len(fac)==1
        base = fac[0]
        # Use item access throughout so that dict factors work as well as
        # Factor instances (the module-level contract for factor interfaces).
        self.domlist = base['domlist'][:] # shallow copy because we'll mutate
        self.domtypes = dict(base['domtypes']) # idem
        if 'codlist' in base:
            self.codlist = base['codlist']
            self.codtypes = base['codtypes']
        indexdomlist = [] # all indexing doms, in order of occurrence
        indexed = set()   # all indexed doms
        for d in det:
            target = d.codlist[0]
            indexed.add(target)
            self.domlist.remove(target)
            del self.domtypes[target]
            for v in d.domlist:
                # An indexing variable may also be in the base domain; it is
                # moved to the end together with the other indexing variables.
                try:
                    self.domlist.remove(v)
                except ValueError:
                    pass
                if v not in indexdomlist:
                    indexdomlist.append(v)
                    self.domtypes[v] = d.domtypes[v]
        # An indexed variable must not reappear as an indexing variable.
        assert not set(indexdomlist) & indexed
        self.domlist.extend(indexdomlist)
    def __repr__(self):
        return 'Index(det={0},fac={1})'.format(self.det,self.fac)

class Fun(Factor):
    """Function factor from the domain dom to the codomain cod.

    cod and dom are ordered lists of 1-element {var: domtype} dicts. par and
    cls are optional and are only stored as attributes when given (not None).
    """
    yaml_tag = u'!fun'

    def __init__(self,cod,dom,par=None,cls=None):
        self.cod = cod
        self.dom = dom
        self.domlist, self.domtypes = split_ordered_list(dom)
        self.codlist, self.codtypes = split_ordered_list(cod)
        # Optional attributes: absent (rather than None) when not given.
        for name, value in (('par', par), ('cls', cls)):
            if value is not None:
                setattr(self, name, value)

    def __repr__(self):
        shown = ['{0}={1!r}'.format(attr, self[attr])
                 for attr in ('cls','cod','dom','par') if attr in self]
        return 'Fun({0})'.format(','.join(shown))

    @staticmethod
    def random(codlist, domlist, **domtypes):
        """Return a Fun with given (co)domain and random parameters.

        The domain consists of the variables in domlist, in that order.
        An element of domlist can be one of the following:

        - var (a simple variable name)
        - a 1-element dict {var:domtype}

        If domtypes has a var:domtype entry for the same var, domtype
        will override the one in the list. If the list contains only
        var, such an entry is mandatory.
        If domtype x is an int, it will be converted to range(x).
        """
        ncod = len(codlist)
        coddom, coddomsizes = smartdom(codlist+domlist,domtypes)
        cod, dom = coddom[:ncod], coddom[ncod:]
        nvalues = coddomsizes[0]  # size of the first codomain variable
        nconfigs = product(coddomsizes[ncod:])
        # One random codomain index per joint domain configuration.
        par = [random.randrange(nvalues) for _ in xrange(nconfigs)]
        return Fun(cod,dom,par)


class Embed(Factor):
    """Factor built from the list of Fun factors det.

    Its domain is the codomain variables of det[0], det[1], ... (in that
    order), followed by the merged domains of all det entries. A codomain
    variable may not occur elsewhere in that domain.
    """
    yaml_tag = u'!embed'

    def __init__(self,det):
        self.det = det
        domlist, domtypes = mergedoms(det)
        # Prepend codomains back to front so det[0]'s codomain ends up first.
        for fun in reversed(det):
            assert not set(domlist).intersection(fun.codlist)
            domlist = list(fun.codlist) + domlist
            domtypes.update(fun.codtypes)
        self.domlist = domlist
        self.domtypes = domtypes

    def __repr__(self):
        return 'Embed({det})'.format(det=self.det)

class Reorder(Factor):
    """Rename and/or reorder the domain of fac[0].

    dic is a list of 1-element {newvar: oldvar} dicts. The new domain consists
    of the newvars in that order, each with the domtype of oldvar in fac[0].

    Raises ValueError if an entry of dic does not have exactly one item.
    """
    yaml_tag = u'!reorder'
    def __init__(self,dic,fac):
        self.dic = dic
        self.fac = fac
        self.domlist = []
        self.domtypes = {}
        for entry in dic:
            # A multi-item entry would give an arbitrary (dict-order) domain.
            if len(entry) != 1:
                raise ValueError('Reorder entry must be a 1-element dict: '+repr(entry))
            ((newvar,oldvar),) = entry.items()
            self.domlist.append(newvar)
            # Item access so that dict factors work as well as Factor instances.
            self.domtypes[newvar] = fac[0]['domtypes'][oldvar]
    def __repr__(self):
        return 'Reorder({0},{1})'.format(self.dic,self.fac)


class I(Factor):
    """Unit for factor multiplication: a factor with an empty domain.

    Factor.product drops I() operands and returns I() for an empty product.
    In YAML, write '!i {}'.
    """
    # NB: the docstring must come first in the class body; placed after
    # yaml_tag it would be a no-op expression and I.__doc__ would be None.
    yaml_tag = u'!i' # NB use '!i {}' in YAML
    def __init__(self):
        self.domlist = []
        self.domtypes = {}
    def __repr__(self):
        return 'I()'


def getfac(tree):
    """Return the list of child factors of a factor tree, i.e. tree['fac'].

    Raises KeyError if tree has no 'fac' item.
    """
    children = tree['fac']
    return children

def setfac(tree,reschildren):
    """Return a shallow copy of a factor tree whose 'fac' item is reschildren.

    The original tree is left untouched.
    """
    clone = copy.copy(tree)
    clone['fac'] = reschildren
    return clone

def foldfac(tree,
            branch=setfac, leaf=identity, shared=identity,
            getchildren=getfac):
    """Transform a factor tree, taking possibly shared subtrees into account.

    Subtrees are assumed to be listed under a tree['fac'] item (an iterable).
    - If there is no such item, apply the function 'leaf'.
    - If this tree has been encountered before, apply the function 'shared'
      to the previous result.
    - If there are subtrees, recurse into them, collect the results into a
      list, and apply the function 'branch' to this list.

    With the default functions, a copy of the factor tree is returned
    (sharing leaves between the original and the copy). The actual
    traversal is delegated to utils.foldsh.
    """
    callbacks = (branch, leaf, shared, getchildren)
    return foldsh(tree, *callbacks)

def getdetfac(tree):
    """Return the children of a factor tree: its 'det' items followed by its
    'fac' items.

    A tree with only one of the two yields just that list. Raises KeyError
    if neither is present. For SimpleIndex, whose det is a plain
    {var: val} dict, only the 'fac' items are children.
    """
    if isinstance(tree,SimpleIndex):
        return tree['fac']  # SimpleIndex (to be obsoleted): det is no child
    try:
        detchildren = tree['det']
    except KeyError:
        return tree['fac']
    # Build a new list; extending tree['det'] in place would mutate the tree.
    return detchildren + tree.get('fac',[])

def setdetfac(tree,reschildren):
    """Return a shallow copy of tree with its 'det' and 'fac' lists replaced.

    reschildren is consumed in order: first as many items as tree['det'] has
    (if tree has 'det'), then as many as tree['fac'] has (if tree has 'fac').
    The original tree is not modified.

    Raises ValueError if reschildren does not contain exactly that many items.
    """
    # NOTE(review): getdetfac does not count a SimpleIndex's det as children,
    # but this function still reserves len(tree['det']) items for it, so
    # SimpleIndex trees presumably do not fold correctly. Confirm before
    # relying on folddetfac for SimpleIndex.
    children = list(reschildren)
    ndet = len(tree['det']) if 'det' in tree else 0
    nfac = len(tree['fac']) if 'fac' in tree else 0
    if len(children) != ndet + nfac:
        raise ValueError('setdetfac: expected {0} children, got {1}'.format(
            ndet + nfac, len(children)))
    newobj = copy.copy(tree)
    if 'det' in tree:
        newobj['det'] = children[:ndet]
    if 'fac' in tree:
        newobj['fac'] = children[ndet:]
    return newobj

def folddetfac(tree,
               branch=setdetfac, leaf=identity, shared=identity,
               getchildren=getdetfac):
    """Like foldfac, but by default the children of a node are its 'det'
    items followed by its 'fac' items (via getdetfac/setdetfac).

    With the default functions, a copy of the factor tree is returned.
    """
    return foldfac(tree, branch, leaf, shared, getchildren)
               

def reconstruct(tree):
    """Reconstruct a factor tree using the class constructors (which YAML does
    not use).

    Factor nodes are rebuilt by calling their class with all key/value
    combinations of their __dict__ as keyword arguments, with the det/fac
    children replaced by their reconstructed versions. The constructors do
    not accept auxiliary attributes such as 'domlist', so the input is
    presumably a tree without them (e.g. freshly loaded from YAML).

    Dict nodes are shallow-copied. If they lack 'domlist'/'domtypes', these
    are copied from the first factor in their 'fac' list, or set to [] / {}
    when there is none.

    Raises TypeError for nodes that are neither Factor instances nor dicts.
    """
    def branch(tree,reschildren):
        if isinstance(tree,Factor):
            kwargs = setdetfac(tree.__dict__,reschildren)
            return tree.__class__(**kwargs)
        if isinstance(tree,dict):
            newtree = setdetfac(tree,reschildren)
            if 'domlist' not in newtree:
                # By convention a dict's sub-factor sits in a singleton 'fac'
                # list. reschildren[0] would be a 'det' child if there is one.
                fac = newtree.get('fac') or []
                if fac:
                    newtree['domlist'] = fac[0]['domlist']
                    newtree['domtypes'] = fac[0]['domtypes']
                else:
                    newtree['domlist'] = []
                    newtree['domtypes'] = {}
            return newtree
        raise TypeError('reconstruct: unexpected node {0!r}'.format(tree))
    def leaf(tree):
        if isinstance(tree,Factor):
            return tree.__class__(**tree.__dict__)
        if isinstance(tree,dict):
            if 'domlist' not in tree:
                tree = copy.copy(tree)
                tree['domlist'] = []
                tree['domtypes'] = {}
            return tree
        raise TypeError('reconstruct: unexpected node {0!r}'.format(tree))
    return folddetfac(tree,branch,leaf)

def remove_doms(tree):
    """Return a copy of a factor tree without the auxiliary attributes
    'domlist', 'domtypes', 'codlist' and 'codtypes'.

    The input tree is not modified. Nodes that carry none of these
    attributes are reused as-is when they are leaves.
    """
    auxkeys = ('domlist','domtypes','codlist','codtypes')
    def strip(obj):
        # Delete whichever auxiliary keys are present; obj must be a private copy.
        for key in auxkeys:
            if key in obj:
                del obj[key]
        return obj
    def branch(obj,ch):
        # Always rebuild with the processed children (setdetfac makes a copy);
        # returning obj itself would discard the stripped children.
        return strip(setdetfac(obj,ch))
    def leaf(obj):
        if any(key in obj for key in auxkeys):
            return strip(copy.copy(obj))
        return obj
    return folddetfac(tree,branch,leaf)

#TODO: optionally use "is" instead of "==" for domain equivalence?
# and maybe also test the domain names using "is"?
def mergedoms(fac):
    """Merge the domlist/domtypes of a list of factors. Detect domtype conflicts.

    Factors may be Factor instances or dict factors; both are accessed via
    item access (f['domlist'], f['domtypes']).

    Returns (domlist, domtypes): a new list of the variables in order of
    first occurrence, and a new {var: domtype} dict.

    Raises ValueError if a variable occurs with unequal domtypes.
    """
    domlist = []
    domtypes = {}
    for f in fac:
        ftypes = f['domtypes']
        for v in f['domlist']:
            vtype = ftypes[v]
            if v in domtypes:
                if not domtypes[v]==vtype:
                    raise ValueError("domain conflict for {0}: {1} vs {2}".format(v,domtypes[v],vtype))
            else:
                domlist.append(v)
                domtypes[v] = vtype
    return domlist,domtypes

def split_ordered_list(dom):
    """Split a [{var: domtype}, ...] ordered list into (domlist, domtypes).

    domlist holds the variables in order and domtypes maps each variable to
    its domtype (the inverse of makedom).

    Raises ValueError if an entry is not a 1-element dict.
    """
    domlist = []
    domtypes = {}
    for m in dom:
        # A multi-key entry would put all keys in domtypes but only one
        # (arbitrary) key in domlist.
        if len(m) != 1:
            raise ValueError('domain entry must be a 1-element dict: '+repr(m))
        ((var,domtype),) = m.items()  # works on Python 2 and 3
        domlist.append(var)
        domtypes[var] = domtype
    return (domlist,domtypes)

def smartdom(domlist,domtypes):
    """Process a convenience domlist (e.g. for Multinom.random()).

    Each element of domlist is either a variable name, whose domtype is then
    looked up in domtypes, or a 1-element dict {var: domtype}. In the latter
    case, an entry for var in domtypes takes precedence. An int domtype n is
    expanded to the list [0, ..., n-1].

    Returns (dom, domsizes): dom in [{var: domtype}, ...] format and the list
    of domain sizes (None for a domtype that has no len()).

    Raises KeyError if a plain variable has no entry in domtypes, and
    ValueError for a duplicate variable or a dict entry without exactly one
    item.
    """
    dom = []  # in [{var:domtype},...] format
    varset = set() # to check for duplicates
    domsizes = []

    for d in domlist:
        if isinstance(d,dict):
            if len(d) != 1:
                raise ValueError('domlist entry must be a 1-element dict: '+repr(d))
            ((var,domtype),) = d.items()
            # override domtype from domtypes, if present
            domtype = domtypes.get(var,domtype)
        else:
            var = d
            domtype = domtypes[var]
        try: # will work if domtype is an int
            domtype = list(range(domtype))  # a real list on Python 3 too
        except TypeError:
            pass # not an int, leave domtype as it is
        if var in varset:
            raise ValueError('duplicate variable in domlist: '+str(var))
        dom.append({var:domtype})
        try:
            domsizes.append(len(domtype))
        except (TypeError, AttributeError): # len() raises TypeError, e.g. for None
            domsizes.append(None)
        varset.add(var)
    return dom,domsizes

def makedom(domlist,domtypes):
    """Build a [{var: domtype}, ...] ordered list from a domlist and a
    {var: domtype} dict; entries of domtypes not in domlist are ignored."""
    dom = []
    for var in domlist:
        dom.append({var: domtypes[var]})
    return dom
