import numpy as np
import types
from itertools import *

"""Dictionary-dimensioned arrays. Now without value lists, each dimension is supposed to consist of integers."""

def xmap(fun, unbound, bound=None):
    """Apply fun to every combination of values drawn from the sequences in unbound.

    The result is nested lists, one nesting level per entry of unbound, mirroring
    the order of the value sequences. bound holds the arguments already fixed by
    outer recursion levels and is prepended to each call; when unbound is empty,
    fun(*bound) itself is returned.
    """
    prefix = [] if bound is None else bound
    if not len(unbound):
        return fun(*prefix)
    head, rest = unbound[0], unbound[1:]
    return [xmap(fun, rest, prefix + [value]) for value in head]


class FArray(object):
    """A numpy array whose axes are addressed by dimension name instead of position.

    Attributes:
        arr   -- the underlying numpy.ndarray (not copied on construction).
        dims  -- list of dimension names, one per axis of ``arr``, in axis order.
        shape -- dict mapping each dimension name to the size of its axis.

    Arithmetic with another FArray first aligns both operands on the union of
    their dimensions (see the align() method), so operands broadcast by name.
    """

    def __init__(self, base, dimensions):
        """Build an FArray.

        base may be
        - a numpy.ndarray: ``dimensions`` is a sequence with one name per axis;
          the array is used as-is (shared, not copied);
        - a Python function or method: ``dimensions`` is a dict {name: size} or a
          sequence of (name, size) pairs, and the array is built with
          ``np.fromfunction(base, sizes)`` (which passes numpy's default float
          index arrays to base).

        Raises Exception if the number of names differs from the array rank, or
        if base has an unsupported type.
        """
        if isinstance(base, np.ndarray):
            dims = list(dimensions)
            if len(base.shape) != len(dims):
                raise Exception('Number of array dimensions does not match length of name list.')
            self.arr = base
            # Keep a private list: methods such as sum() rely on list operations,
            # and the caller's sequence must not be aliased.
            self.dims = dims
        elif isinstance(base, (types.FunctionType, types.MethodType)):
            if isinstance(dimensions, dict):
                pairs = list(dimensions.items())
            else:
                pairs = list(dimensions)
            self.dims = [name for name, _ in pairs]
            sizes = tuple(size for _, size in pairs)
            self.arr = np.fromfunction(base, sizes)
        else:
            raise Exception('NIY for ' + repr(type(base)))
        self.shape = dict(zip(self.dims, self.arr.shape))

    @staticmethod
    def _as_ndarray(x):
        """Return x unchanged if it is an ndarray, otherwise wrap it (e.g. a numpy scalar) in one."""
        return x if isinstance(x, np.ndarray) else np.array(x)

    def __call__(self, **pairs):
        ''' Index by dimension name, e.g. fa(x=3, y=slice(1, 4)).

        __call__ also provides indexing by FArrays, like numpy provides indexing by arrays.
        The FArray resulting from __call__ is like the original array minus the indexed dimensions,
        plus the dimensions of the indexing FArrays.
        When using multiple array indices, their shapes should all be the same (modulo broadcasting).
        Do not index again on dimensions of indexing FArrays.
        With no arguments, self itself is returned.
        Raises Exception when indexing a dimension that self does not have.'''
        # A problem is the way numpy shapes the result array for multiple indices:
        # - if the indices form an unbroken sequence (e.g. a[:,b,b,:,:] but also a[:,b,3,b,:])
        #   then the indexing dimensions are inserted at that place
        # - if the indexing arrays form a broken sequence (e.g. a[:,b,:,b,:] but also a[3,:,b,b,:])
        #   then the indexing dimensions are inserted at the start
        # We use the variable infirstsequence to determine which is the case.

        if not pairs:
            return self
        unknown = set(pairs) - set(self.dims)
        if unknown:
            raise Exception('Indexing over non-present dimensions: ' + repr(sorted(unknown)))

        # FArray-valued indices, aligned on one common dimension order (dimorder).
        fas = dict((d, i) for d, i in pairs.items() if isinstance(i, FArray))
        fa_names = list(fas)
        arraylist1, dimorder = align([fas[d] for d in fa_names])
        # Dimensions of self that also occur in the index arrays, but are not indexed
        # themselves, are walked in lockstep with those arrays through an arange index.
        noslices = dict((d, FArray(np.arange(s), [d]))
                        for d, s in self.shape.items() if d in dimorder and d not in fas)
        ns_names = list(noslices)
        arraylist2, _ = align([noslices[d] for d in ns_names], dimorder)
        arraydict = dict(zip(fa_names + ns_names, list(arraylist1) + list(arraylist2)))

        indices = []
        for d in self.dims:
            if d in arraydict:
                indices.append(arraydict[d])
            elif d in pairs:
                indices.append(pairs[d])
            else:
                indices.append(slice(None, None, None))

        # Work out where numpy will place the broadcast index dimensions (see above).
        nonsliceindices = False
        infirstsequence = False
        insertat = 0
        resultdims = []
        for d, i in zip(self.dims, indices):
            if isinstance(i, slice):
                resultdims.append(d)
                infirstsequence = False
            elif not nonsliceindices:
                nonsliceindices = True
                infirstsequence = True
                insertat = len(resultdims)
            elif not infirstsequence:
                insertat = 0
        result = self._as_ndarray(self.arr[tuple(indices)])
        if nonsliceindices:
            resultdims[insertat:insertat] = dimorder
        return FArray(result, resultdims)

    def __repr__(self):
        return 'FArray(dimensions=' + repr(self.dims) + ', base=\n' + repr(self.arr) + ')'

    def sum(self, sumdims):
        """Sum over one or more dimensions.

        sumdims is an iterable of dimension names. The result keeps the remaining
        dimensions in their original order. Raises ValueError if a name is not a
        dimension of self (or is given twice).
        """
        # better idea?: collapse summed dimensions into one, then sum.
        # but they would have to be adjacent...
        result = self.arr
        remainingdims = list(self.dims)
        for d in sumdims:
            dix = remainingdims.index(d)
            result = np.add.reduce(result, dix)
            del remainingdims[dix]
        return FArray(self._as_ndarray(result), remainingdims)

    def dimensions(self):
        """Return a copy of the list of dimension names."""
        return self.dims[:]

    def reorder(self, order=None):
        """Return self.arr with its axes laid out in ``order``.

        order is an iterable of dimension names (default: self.dims). Entries equal
        to np.newaxis (None) insert length-1 axes at that position. An empty order
        returns self.arr unchanged. The result is typically a view, not a copy.
        """
        # maybe return dimension order
        if order is None:
            order = self.dims
        order = list(order)  # accept any iterable, not only lists/tuples
        if not order:
            return self.arr
        permutation = [self.dims.index(n) for n in order if n is not np.newaxis]
        extension = tuple(np.newaxis if n is np.newaxis else slice(None, None, None) for n in order)
        return self.arr.transpose(permutation)[extension]

    def align(self, other):
        """Results in two ndarrays with aligned dimensions. These need not be of the same shape;
        the dimension size of one ndarray may be 1 where that of the other is not.
        Useful for preprocessing for a product-join.
        Returns (arr1, arr2, alldims)."""
        [arr1, arr2], alldims = align([self, other])  # this module's align function (outside the class)
        return (arr1, arr2, alldims)

    def cat(self, other, dimension):
        """ Concatenates two FArrays in the specified dimension.
        If an array lacks dimensions that the other array has, these dimensions are added and filled with copies
        to match the shape of the other array."""
        arr1, arr2, alldims = self.align(other)
        for axis, d in enumerate(alldims):
            # align() gave a missing dimension length 1; repeat it to the other array's size.
            if d not in self.dims:
                arr1 = np.repeat(arr1, other.shape[d], axis)
            if d not in other.dims:
                arr2 = np.repeat(arr2, self.shape[d], axis)
        return FArray(np.concatenate([arr1, arr2], alldims.index(dimension)), alldims)

    def match(self, other, dtype=bool):
        """Returns an FArray R whose dimensions are the unions of self and other, with
        R(a=1,b=2,c=3) = ( S(a=1,b=2)==O(b=2,c=3) )
        Compare with __eq__."""
        (arr1, arr2, alldims) = self.align(other)
        return FArray(np.array(arr1 == arr2, dtype), alldims)

    def __eq__(self, other):
        """Unlike Numpy, returns a Boolean instead of an array of Booleans. The latter can be obtained
        using self.match(other).
        Two FArrays are equal when they have the same dimension names, the same size
        per dimension and equal values (regardless of axis order)."""
        if not isinstance(other, FArray):
            return NotImplemented
        # Comparing sizes first avoids broadcasting mismatched arrays into a wrong answer or an error.
        if set(self.dims) != set(other.dims) or self.shape != other.shape:
            return False
        return bool((self.arr == other.reorder(self.dims)).all())

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _binary(self, other, ufunc):
        """Return a new FArray holding ufunc(self, other).

        An FArray operand is aligned by dimension name first; anything else is
        assumed to be a numeric scalar and combined with self.arr directly.
        """
        if isinstance(other, FArray):
            arr1, arr2, alldims = self.align(other)
            return FArray(self._as_ndarray(ufunc(arr1, arr2)), alldims)
        return FArray(self._as_ndarray(ufunc(self.arr, other)), self.dims)

    def _inplace(self, other, ufunc):
        """Apply ufunc(self, other) in place where possible and return self.

        WATCH OUT WITH MUTABILITY: the update is written into self.arr's buffer,
        which may be shared with other FArrays (e.g. ones produced by rename()).
        When other adds dimensions, the result no longer fits in that buffer, and
        self.arr is rebound to a newly allocated array instead.
        """
        if isinstance(other, FArray):
            arr1, arr2, alldims = self.align(other)
            if np.broadcast(arr1, arr2).shape == arr1.shape:
                ufunc(arr1, arr2, out=arr1)
            else:
                arr1 = self._as_ndarray(ufunc(arr1, arr2))
            self.arr = arr1
            self.dims = list(alldims)
            self.shape = dict(zip(self.dims, self.arr.shape))
        else:  # assume numeric scalar
            ufunc(self.arr, other, out=self.arr)
        return self

    def __mul__(self, other):
        return self._binary(other, np.multiply)

    def __rmul__(self, other):
        return self.__mul__(other)

    # WATCH OUT WITH MUTABILITY
    def __imul__(self, other):
        return self._inplace(other, np.multiply)

    # WATCH OUT WITH MUTABILITY
    def __iadd__(self, other):
        return self._inplace(other, np.add)

    def __add__(self, other):
        return self._binary(other, np.add)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self._binary(other, np.subtract)

    def __div__(self, other):
        # np.divide matches the '/' operator of the running Python version.
        return self._binary(other, np.divide)

    def __truediv__(self, other):
        return self._binary(other, np.true_divide)

    def rename(self, old2new):
        """Return a new FArray with renamed dimensions, sharing the same underlying
        array (and dimension order).

        Dimensions missing from old2new keep their name. Raises Exception when
        the renaming would give two dimensions the same name.
        """
        newnamelist = []
        for old in self.dims:
            try:
                newnamelist.append(old2new[old])
            except KeyError:
                newnamelist.append(old)
        if len(set(newnamelist)) != len(newnamelist):
            raise Exception('Renaming produces duplicate dimension names: ' + repr(newnamelist))
        return FArray(self.arr, newnamelist)

    def copy(self):
        """Return an FArray with a copy of the underlying array."""
        return FArray(self.arr.copy(), self.dims)

    def normalize(self, dim=None):
        ''' Make sure that summing over dim returns an array of ones.

        dim may be omitted for a one-dimensional FArray; otherwise an unknown (or
        omitted) dim raises ValueError. Uses true division, so integer arrays are
        normalized to floats instead of being floor-divided.'''
        if dim is None and len(self.dims) == 1:
            dimix = 0
        else:
            dimix = self.dims.index(dim)
        arrsum = np.sum(self.arr, axis=dimix, keepdims=True)  # keep axis so it broadcasts back
        return FArray(np.true_divide(self.arr, arrsum), self.dims)

def align(fas, order=None):
    """Lay out several FArrays along one common dimension order.

    Returns (arrays, order). arrays[k] is fas[k].reorder(...) arranged along
    `order`, with a length-1 axis (np.newaxis) wherever fas[k] lacks a dimension,
    so that the arrays broadcast against each other.

    If order is None, it becomes the union of the input dimensions in order of
    first appearance. Otherwise it must contain every input dimension, and
    Exception is raised if it does not. dict_union raises Exception when two
    arrays give the same dimension different sizes.
    """
    fas = list(fas)  # may be a one-shot iterable; it is traversed more than once below
    alldims = dict_union(fa.shape for fa in fas)  # also validates consistent sizes
    if order is None:
        order = []
        for fa in fas:
            for d in fa.dimensions():
                if d not in order:
                    order.append(d)
    else:
        order = list(order)
        if not set(alldims) <= set(order):
            raise Exception('Union of array dimensions is no subset of order: '
                + repr(list(alldims)) + ' vs ' + repr(order))
    aligned = []
    for fa in fas:
        present = fa.dimensions()
        aligned.append(fa.reorder([d if d in present else np.newaxis for d in order]))
    return (aligned, order)

def randfa(shp):
    '''Return an FArray of the given shape with random values in [0, 1).

    shp is a dict {dimension name: size}. The dimension names are materialized as
    a list, so the result also works on Python 3, where dict.keys() is a view.'''
    names = list(shp.keys())
    return FArray(np.random.random(tuple(shp[name] for name in names)), names)

def dict_union(it):
    """Merge an iterable of dicts into one new dict.

    A key that occurs in several dicts must map to equal values; otherwise an
    Exception is raised. The first dict is copied with its own .copy(), and no
    input dict is modified. An empty iterable gives {}.
    """
    it = iter(it)
    try:
        u = next(it).copy()  # next() works on Python 2.6+ and 3
    except StopIteration:
        return dict()
    for d in it:
        for k, v in d.items():
            if k not in u:
                u[k] = v
            elif u[k] != v:
                raise Exception('Different values for ' + repr(k) + ': ' + repr(u[k]) + ',' + repr(v))
    return u
