import unittest
import symfer as s
import os
import subprocess
import tempfile
import numpy as np
from symfer.utils import TicToc

# Requires jsymfer.jar in the current working directory, with ../examples/Water.net
# reachable from it -- TODO: find a better place to keep the jar.
class TestMultifac(unittest.TestCase):
    '''Compare multi-factor inference in Numpy and jsymfer.

    The same marginal computation is run through ``s.evaluate`` (numpy),
    ``s.evalc`` and the external ``jsymfer.jar``. Every resulting factor
    must agree within an absolute tolerance of 1e-06.
    '''

    @classmethod
    def setUpClass(cls):
        '''Create the two scratch files shared by the tests in this class.'''
        # mkstemp() returns an *open* OS-level descriptor plus the path. Only
        # the path is used, so close the descriptor right away. Otherwise it
        # leaks for the whole run, and on Windows it also blocks os.remove()
        # in tearDownClass.
        fd1, cls.tempfile1 = tempfile.mkstemp()
        os.close(fd1)
        fd2, cls.tempfile2 = tempfile.mkstemp()
        os.close(fd2)

    @classmethod
    def tearDownClass(cls):
        '''Delete the scratch files created in setUpClass.'''
        os.remove(cls.tempfile1)
        os.remove(cls.tempfile2)

    def dumpnload(self, factor):
        '''Round-trip *factor* through YAML and assert it comes back equal.'''
        # The class only defines tempfile1/tempfile2; ``self.tempfile`` does not
        # exist, so the scratch file must be referenced explicitly.
        s.dumpyaml(factor, self.tempfile1)
        self.assertEqual(s.loadyaml(self.tempfile1), factor)

    def allclose(self, m1, m2):
        '''Return True if the ``par`` arrays of *m1* and *m2* match.

        Both arrays must have the same shape and be elementwise equal within
        ``atol=1e-06``. The shape check matters because ``np.allclose``
        broadcasts, so e.g. ``[x]`` would otherwise "match" ``[x, x]``.
        '''
        a = np.array(m1.par)
        b = np.array(m2.par)
        return a.shape == b.shape and np.allclose(a, b, atol=1e-06)

    # ------------- TESTS -------------------

    def test_multifac(self):
        '''Marginals of the Water network agree across numpy, C and jsymfer.'''
        # The path is relative to the current working directory.
        model = s.loadhugin('../examples/Water.net')
        target = ['CKNI_12_30']
        ve = s.ve_minweight(model, target)
        marg = s.marginals(s.junctiontree(ve))
        with TicToc() as nptime:
            npres = s.evaluate(marg)
        #print('numpy: ' + str(nptime) + ' sec')
        with TicToc() as ctime:
            cres = s.evalc(marg)
        #print('c: ' + str(ctime) + ' sec')
        # Hand the marginal spec to jsymfer as YAML and capture its stdout
        # in the second scratch file. Passing an argument list (no shell)
        # keeps paths with spaces or shell metacharacters intact.
        s.dumpyaml(marg, self.tempfile1)
        with open(self.tempfile2, 'w') as out:
            subprocess.check_call(
                ['java', '-jar', 'jsymfer.jar', self.tempfile1], stdout=out)
        jres = s.loadyaml(self.tempfile2)
        # zip() stops at the shortest input, so without these checks a back
        # end returning too few factors would pass unnoticed.
        self.assertEqual(len(npres['fac']), len(cres['fac']))
        self.assertEqual(len(npres['fac']), len(jres['fac']))
        for m1, m2, m3 in zip(npres['fac'], cres['fac'], jres['fac']):
            self.assertTrue(self.allclose(m1, m2))
            self.assertTrue(self.allclose(m1, m3))
                    
        

if __name__ == "__main__":
    # Only when the file is executed directly (not imported): hand control
    # to unittest's command-line runner for the TestCase classes defined here.
    unittest.main()


