"""
Tests some functions inside utilities/import_datasets.py
"""
import unittest
import numpy as np
from utilities.import_datasets import loadvars, array_generator
import os
current_dir = os.path.dirname(__file__)
# All have three branches 'A','B','C' 100 events each
path_pi = os.path.join(current_dir, 'dummy/dummy_pi.root') # A = 1, B = 2, C = 3
path_k = os.path.join(current_dir, 'dummy/dummy_k.root') # A = 4, B = 5, C = 6
path_dat = os.path.join(current_dir, 'dummy/dummy_dat.root') # A = 7, B = 8, C = 9
[docs]class TestLoadvars(unittest.TestCase):
"""
Class of tests for the load_vars() function
"""
[docs] def test_loadvars_flag(self):
"""
Tests that the array extracted from a root file corresponds to the
expected one. We use a dummy Root file with variables generate for this
purpose
"""
tree = 'dummytree;1'
dummyvars = ('A', 'C')
v_pi, v_k = loadvars(path_pi, path_k, tree,
vars=dummyvars, flag_column=True)
self.assertEqual(v_pi.shape, (100, 3))
self.assertEqual(v_k.shape, (100, 3))
# !!!! can't figure out how to do this with the np directly
self.assertAlmostEqual(tuple(v_pi[:, 0]), tuple(1*np.ones(100)))
self.assertAlmostEqual(tuple(v_pi[:, 1]), tuple(3*np.ones(100)))
self.assertAlmostEqual(tuple(v_pi[:, 2]), tuple(np.zeros(100)))
self.assertAlmostEqual(tuple(v_k[:, 0]), tuple(4*np.ones(100)))
self.assertAlmostEqual(tuple(v_k[:, 1]), tuple(6*np.ones(100)))
self.assertAlmostEqual(tuple(v_k[:, 2]), tuple(np.ones(100)))
[docs] def test_loadvars_noflag(self):
"""
Tests that the array extracted from a root file corresponds to the
expected one. We use a dummy Root file with variables generate for this
purpose
"""
tree = 'dummytree;1'
dummyvars = ('C', 'B')
v_pi, v_k = loadvars(path_pi, path_k, tree,
vars=dummyvars, flag_column=False)
self.assertEqual(v_pi.shape, (100, 2))
self.assertEqual(v_k.shape, (100, 2))
self.assertAlmostEqual(tuple(v_pi[:, 0]), tuple(3*np.ones(100)))
self.assertAlmostEqual(tuple(v_pi[:, 1]), tuple(2*np.ones(100)))
self.assertAlmostEqual(tuple(v_k[:, 0]), tuple(6*np.ones(100)))
self.assertAlmostEqual(tuple(v_k[:, 1]), tuple(5*np.ones(100)))
[docs] def test_loadvars_1D(self):
"""
Tests that the array extracted from a root file corresponds to the
expected one. We use a dummy Root file with variables generate for this
purpose
"""
tree = 'dummytree;1'
dummyvars = ('B')
v_pi, v_k = loadvars(path_pi, path_k, tree,
vars=dummyvars, flag_column=False, flatten1d=False)
self.assertEqual(v_pi.shape, (100, 1))
self.assertEqual(v_k.shape, (100, 1))
self.assertAlmostEqual(tuple(v_pi[:, 0]), tuple(2*np.ones(100)))
self.assertAlmostEqual(tuple(v_k[:, 0]), tuple(5*np.ones(100)))
v_pi, v_k = loadvars(path_pi, path_k, tree,
vars=dummyvars, flag_column=False, flatten1d=True)
self.assertEqual(v_pi.shape, (100,))
self.assertEqual(v_k.shape, (100,))
self.assertAlmostEqual(tuple(v_pi), tuple(2*np.ones(100)))
self.assertAlmostEqual(tuple(v_k), tuple(5*np.ones(100)))
[docs]class TestArrayGenerator(unittest.TestCase):
"""
Class of tests for the array_generator() function
"""
[docs] def test_training_testing(self):
"""
Tests that the array extracted from a root file corresponds to the
expected one. We use a dummy Root file with variables generate for this
purpose
"""
tree = 'dummytree;1'
dummyvars = ('A', 'C')
v_mc, v_dat = array_generator(
(path_pi, path_k, path_dat), tree, vars=dummyvars, n_mc=100, n_data=50)
self.assertEqual(v_mc.shape, (100, 3))
self.assertEqual(v_dat.shape, (50, 2))
self.assertAlmostEqual(tuple(v_dat[:, 0]), tuple(7*np.ones(50)))
self.assertAlmostEqual(tuple(v_dat[:, 1]), tuple(9*np.ones(50)))
if __name__ == '__main__':
unittest.main()