from . import factor
import yapgvb

# Public API of this module: only the plan-drawing entry point is exported.
__all__ = ['drawplan']

def str_elems(s, divider=','):
    """Return the items of *s*, each converted with ``str``, joined by *divider*.

    >>> str_elems([1, 2, 3])
    '1,2,3'
    """
    return divider.join(map(str, s))

def drawplan(plan, filename='out.svg'):
    """Draw a SumProd/Index/Multinom factor expression using graphviz.

    Walks the expression rooted at *plan*, creating one graph node per
    distinct object (an object reachable along several paths is drawn once),
    lays the graph out with ``dot`` and renders it to *filename*.

    :param plan: root of the expression.  Handled node kinds are
        ``factor.SumProd`` (children in ``plan.fac``), ``factor.Multinom``
        (leaf labelled by ``plan.domlist``), ``factor.Index`` (labelled by
        ``plan.det``, single child ``plan.fac[0]``), a dict holding a
        ``'fac'`` sequence, and any other object, drawn as a rectangle
        labelled with its class name and ``id`` attribute.
    :param filename: output path handed to ``graph.render``.
    """

    graph = yapgvb.Digraph()

    def add_child_edge(parentnode, child, mem):
        """Draw *child* and connect *parentnode* to it with a reversed edge."""
        childnode = add_subtree(child, mem)
        edge = parentnode >> childnode
        # graphviz dir=back: the arrowhead is drawn at the parent end.
        edge.dir = 'back'

    def add_subtree(plan, mem):
        """Return the graph node for *plan*, creating it and its children if needed.

        *mem* maps ``id(obj)`` to the node already created for that object.
        NOTE(review): keying on id() assumes every visited object stays alive
        for the whole walk (they are reached from the root) -- confirm.
        """
        key = id(plan)
        if key in mem:
            return mem[key]
        plannode = graph.add_node()
        # Register before recursing so a cyclic reference terminates.
        mem[key] = plannode
        plannode.shape = 'none'
        plannode.fontname = 'sans'
        if isinstance(plan, factor.SumProd):
            plannode.label = '/' + str_elems(plan.arg)
            for op in plan.fac:
                add_child_edge(plannode, op, mem)
        elif isinstance(plan, factor.Multinom):
            plannode.label = str_elems(plan.domlist)
        elif isinstance(plan, factor.Index):
            # '%s' formatting accepts non-string variables/values; plain '+'
            # concatenation raised TypeError for them.  (iteritems: Python 2.)
            plannode.label = ','.join('%s:%s' % (var, val)
                                      for var, val in plan.det.iteritems())
            add_child_edge(plannode, plan.fac[0], mem)
        elif isinstance(plan, dict):
            # NOTE(review): the children are drawn but not connected to this
            # blank node -- kept as before; confirm whether edges were intended.
            plannode.label = ''
            for ch in plan['fac']:
                add_subtree(ch, mem)
        else:
            plannode.shape = 'rectangle'
            # Built in one go (not read back from the node); '%s' also copes
            # with a non-string ``id``.
            plannode.label = '%s\n%s\n' % (plan.__class__.__name__, plan.id)
        return plannode

    add_subtree(plan, {})
    graph.layout('dot')
    graph.render(filename)
            


