import itertools
from .utils import product
from .factor import *

# Public API of this module.
__all__ = ['detect_noisy_or']

def detect_noisy_or(mn, tolerance=0.0001):
    """Try to rewrite a binary child's conditional table as a noisy-OR.

    ``mn`` must expose ``domlist`` (variable names, child first),
    ``domtypes`` (name -> domain) and ``par`` (flat parameter list whose
    ordering has the leftmost ``domlist`` variable changing fastest, as
    produced by ``binarycount``).  Every domain must have exactly two
    values, assumed to be ordered as [False, True].

    With ``leak = P(child=False | all parents False)`` and
    ``q_i = P(child=False | only parent i True) / leak``, the table is
    accepted when, for every parent configuration,
    ``P(child=False | parents)`` equals ``leak * prod(q_i for active i)``
    and ``P(child=True | parents)`` equals one minus that, each to within
    ``tolerance``.

    Returns:
        None if the table is not detected as a noisy-OR.  Otherwise a pair
        ``(sumout, fac)``.  ``fac`` maps each new auxiliary variable
        (``<parent>_aux`` per parent and, when ``leak < 1``, ``<child>_leak``)
        to a ``Multinom`` over it.  It also maps the child name to an
        ``'or'`` ``Fun`` of those auxiliary variables.  ``sumout`` lists the
        auxiliary variable names, which are meant to be summed out.
    """
    # assume child is first, and domains are ordered as [False,True]
    if not all(len(d) == 2 for d in mn.domtypes.values()):
        return None
    ninputs = len(mn.domlist) - 1
    # Index 0 is the row with the child and all parents False.
    leak = mn.par[0]
    if leak <= 0:
        # q_i below would divide by zero; report "not detected" instead of
        # raising.
        return None
    # Index 2**(i+1): child False, only parent i True (bit 0 is the child).
    qs = [mn.par[2 ** (i + 1)] / leak for i in range(ninputs)]
    # q_i > 1 means parent i makes child=False *more* likely, which a
    # noisy-OR input cannot express.  Without this check every
    # single-parent table would pass and the aux factor would get 1-q < 0.
    if any(q > 1 + tolerance for q in qs):
        return None

    # Explicit checks instead of `assert`, which is stripped under -O.
    # `not (x < tol)` (rather than `x >= tol`) also rejects NaN.
    noisy_or_prob = None
    for p, indices in zip(mn.par, binarycount(1 + ninputs)):
        if indices[0] == 0:  # output is False
            noisy_or_prob = leak
            for q, ix in zip(qs, indices[1:]):
                if ix:
                    noisy_or_prob *= q
            if not abs(p - noisy_or_prob) < tolerance:
                return None
        else:  # output is True
            # The child bit changes fastest, so the previous row was the
            # child=False row for this same parent configuration.
            if not abs(p + noisy_or_prob - 1) < tolerance:
                return None

    child = mn.domlist[0]
    fac = {}
    booldom = [False, True]
    for var, q in zip(mn.domlist[1:], qs):
        auxvar = var + '_aux'
        dom = [{auxvar: booldom}, {var: mn.domtypes[var]}]
        # Clip a tolerance-sized overshoot so that 1-q stays >= 0.
        q = min(q, 1.0)
        # Rows (aux, var), leftmost fastest as for mn.par above:
        # (F,F)=1, (T,F)=0, (F,T)=q, (T,T)=1-q.
        par = [1.0, 0.0, q, 1 - q]
        fac[auxvar] = Multinom(dom, par)
    if leak < 1:  # or use tolerance?
        leakvar = child + '_leak'
        dom = [{leakvar: booldom}]
        # The OR is False with no parents active only when the leak variable
        # is False, so P(leakvar=False) must equal leak.
        par = [leak, 1 - leak]
        fac[leakvar] = Multinom(dom, par)
    dom = [{auxvar: booldom} for auxvar in fac]
    cod = [{child: mn.domtypes[child]}]
    # Snapshot the auxiliary names before the child entry is added (a Python
    # 3 keys() view would otherwise include the child as well).
    sumout = list(fac)
    fac[child] = Fun(cod=cod, dom=dom, cls='or')
    return sumout, fac  # cannot make a sumprod because of deterministic fun
        
def binarycount(nbits):
    """Count in binary with leftmost tuple element changing fastest.

    Returns a generator over all ``2**nbits`` tuples of 0/1 of length
    ``nbits``, in the order (0,0,...), (1,0,...), (0,1,...), (1,1,...), ...
    Element 0 acts as the least significant bit.  ``nbits == 0`` yields a
    single empty tuple.  ``detect_noisy_or`` relies on this order to walk a
    flat parameter list.
    """
    # product() varies its rightmost element fastest; reversing each tuple
    # makes the leftmost element vary fastest instead.
    fastestright = itertools.product((0, 1), repeat=nbits)
    return (tuple(reversed(pr)) for pr in fastestright)
    
