Explorar o código

Improve config module and it's testing

oz123 %!s(int64=10) %!d(string=hai) anos
pai
achega
23fee72d93
Modificáronse 2 ficheiros con 22 adicións e 2 borrados
  1. 15 0
      pwman/tests/db_tests.py
  2. 7 2
      pwman/util/config.py

+ 15 - 0
pwman/tests/db_tests.py

@@ -377,6 +377,8 @@ class ConfigTest(unittest.TestCase):
         self.db = factory.create(self.dbtype, dbver)
         self.tester = SetupTester(dbver)
         self.tester.create()
+        self.orig_config = config._conf.copy()
+        self.orig_config['Encryption'] = {'algorithm': 'AES'}
 
     def test_config_write(self):
         _filename = os.path.join(os.path.dirname(__file__),
@@ -403,6 +405,7 @@ class ConfigTest(unittest.TestCase):
     def test_add_default(self):
         config.add_defaults({'Section1': {'name': 'value'}})
         self.assertIn('Section1', config._defaults)
+        config._defaults.pop('Section1')
 
     def test_get_conf(self):
         cnf = config.get_conf()
@@ -458,3 +461,15 @@ class ConfigTest(unittest.TestCase):
         # args.cfile does not exist, hence the config values
         # should be the same as in the defaults
         config.set_config(foo)
+
+    def test_get_conf_options(self):
+        Args = namedtuple('args', 'cfile, dbase, algo')
+        args = Args(cfile='nosuchfile', dbase='dummy.db', algo='AES')
+        self.assertRaises(Exception, get_conf_options, (args, 'False'))
+        config._defaults['Database']['type'] = 'SQLite'
+        # config._conf['Database']['type'] = 'SQLite'
+        xsel, dbtype = get_conf_options(args, 'True')
+        self.assertEqual(dbtype, 'SQLite')
+
+    def tearDown(self):
+        config._conf = self.orig_config.copy()

+ 7 - 2
pwman/util/config.py

@@ -41,6 +41,11 @@ _conf = dict()
 _defaults = dict()
 
 
+def set_conf(conf_dict):
+    global _conf
+    _conf = conf_dict
+
+
 def set_defaults(defaults):
     global _defaults
     _defaults = defaults
@@ -49,7 +54,7 @@ def set_defaults(defaults):
 def add_defaults(defaults):
     global _defaults
     for n in defaults.keys():
-        if not n in _defaults:
+        if n not in _defaults:
             _defaults[n] = dict()
         for k in defaults[n].keys():
             _defaults[n][k] = defaults[n][k]
@@ -74,7 +79,7 @@ def get_value(section, name):
 
 def set_value(section, name, value):
     global _conf
-    if not section in _conf:
+    if section not in _conf:
         _conf[section] = dict()
     _conf[section][name] = value