Browse Source

Continue refactoring: dbfile -> dburi

oz123 10 years ago
parent
commit
4b832a1713

+ 2 - 3
pwman/__init__.py

@@ -133,9 +133,8 @@ def get_conf_options(args, OSX):
 def get_db_version(config, dbtype, args):
     # This method is seriously biased towards SQLite.
     # TODO: make this more Postgresql\Network Database friendly
-    if os.path.exists(config.get_value("Database", "filename")):
-        dbver = factory.check_db_version(dbtype, config.get_value("Database",
-                                                                  "filename"))
+    if os.path.exists(config.get_value("Database", "dburi")):
+        dbver = factory.check_db_version(config.get_value("Database", "dburi"))
     else:
         dbver = __DB_FORMAT__
     return dbver

+ 13 - 8
pwman/data/factory.py

@@ -36,20 +36,25 @@ if sys.version_info.major > 2:  # pragma: no cover
 else:
     from urlparse import urlparse
 
+import os
+
 from pwman.data.database import DatabaseException
 from pwman.data.drivers import sqlite
 
 
-def check_db_version(ftype, filename):
-    if ftype == "SQLite":
+def check_db_version(dburi):
+    dburi = urlparse(dburi)
+    dbtype = dburi.scheme
+    filename = os.path.abspath(dburi.path)
+    if dbtype == "sqlite":
         ver = sqlite.SQLite.check_db_version(filename)
         try:
             return float(ver.strip("\'"))
         except ValueError:
             return 0.3
     # TODO: implement version checks for other supported DBs.
-    if ftype == "Postgresql":
-        ver = sqlite.PostgresqlDatabase.check_db_version(filename)
+    if dbtype == "Postgresql":
+        ver = sqlite.PostgresqlDatabase.check_db_version(dburi)
 
 
 def create(dbtype, version=None, filename=None):
@@ -58,20 +63,20 @@ def create(dbtype, version=None, filename=None):
     Create a Database instance.
     'type' can only be 'SQLite' at the moment
     """
-    if dbtype == "SQLite":
+    if dbtype == "sqlite":
         from pwman.data.drivers import sqlite
         if str(version) == '0.6':
             db = sqlite.SQLite(filename)
         else:
             db = sqlite.SQLite(filename, dbformat=version)
 
-    elif dbtype == "Postgresql":  # pragma: no cover
+    elif dbtype == "postgresql":  # pragma: no cover
         try:
             from pwman.data.drivers import postgresql
             db = postgresql.PostgresqlDatabase()
         except ImportError:
             raise DatabaseException("python-psycopg2 not installed")
-    elif dbtype == "MySQL":  # pragma: no cover
+    elif dbtype == "mysql":  # pragma: no cover
         try:
             from pwman.data.drivers import mysql
             db = mysql.MySQLDatabase()
@@ -98,7 +103,7 @@ def createdb(dburi, version):
         try:
             from pwman.data.drivers import postgresql
             db = postgresql.PostgresqlDatabase(dburi)
-        except ImportError:
+        except ImportError:  # pragma: no cover
             raise DatabaseException("python-psycopg2 not installed")
     elif dbtype == "mysql":  # pragma: no cover
         try:

+ 1 - 1
pwman/tests/test_base_ui.py

@@ -61,7 +61,7 @@ class TestBaseUI(unittest.TestCase):
     def setUp(self):
         "test that the right db instance was created"
         dbver = __DB_FORMAT__
-        self.dbtype = 'SQLite'
+        self.dbtype = 'sqlite'
         self.db = factory.create(self.dbtype, dbver, testdb)
         self.tester = SetupTester(dbver, testdb)
         self.tester.create()

+ 16 - 8
pwman/tests/test_factory.py

@@ -41,26 +41,28 @@ class TestFactory(unittest.TestCase):
 
     def setUp(self):
         "test that the right db instance was created"
-        self.dbtype = 'SQLite'
+        self.dbtype = 'sqlite'
         self.db = factory.create(self.dbtype, __DB_FORMAT__, testdb)
         self.tester = SetupTester(__DB_FORMAT__, testdb)
         self.tester.create()
 
     def test_factory_check_db_ver(self):
-        self.assertEqual(factory.check_db_version('SQLite', testdb), 0.6)
+        self.assertEqual(factory.check_db_version('sqlite://'+testdb), 0.6)
 
     def test_factory_check_db_file(self):
-        db = factory.create('SQLite', version='0.3', filename='baz.db')
+        fn = os.path.join(os.path.dirname(__file__), 'baz.db')
+        db = factory.create('sqlite', filename=fn)
         db._open()
-        self.assertEqual(factory.check_db_version('SQLite', 'baz.db'), 0.3)
-        os.unlink('baz.db')
+        self.assertEqual(factory.check_db_version('sqlite://'+fn), 0.3)
+        os.unlink(fn)
 
     def test_factory_create(self):
-        db = factory.create('SQLite', filename='foo.db')
+        fn = os.path.join(os.path.dirname(__file__), 'foo.db')
+        db = factory.create('sqlite', filename=fn)
         db._open()
-        self.assertTrue(os.path.exists('foo.db'))
+        self.assertTrue(os.path.exists(fn))
         db.close()
-        os.unlink('foo.db')
+        os.unlink(fn)
         self.assertIsInstance(db, SQLite)
         self.assertRaises(DatabaseException, factory.create, 'UNKNOWN')
 
@@ -70,6 +72,12 @@ class TestFactory(unittest.TestCase):
         del db
         db = factory.createdb("postgresql:///pwman", 0.6)
         self.assertIsInstance(db, PostgresqlDatabase)
+        del db
+        db = factory.createdb("sqlite:///test.db", 0.7)
+        self.assertIsInstance(db, SQLite)
+        del db
+        db = factory.createdb("postgresql:///pwman", 0.7)
+        self.assertIsInstance(db, PostgresqlDatabase)
 
 if __name__ == '__main__':
     # make sure we use local pwman

+ 1 - 1
pwman/tests/test_importer.py

@@ -94,7 +94,7 @@ class TestImporter(unittest.TestCase):
         if os.path.exists('importdummy.db'):
             os.unlink('importdummy.db')
         args = Args(import_file=open('import_file.csv'), db='importdummy.db')
-        dbtype, dbver, fname = 'SQLite', 0.6, 'importdummy.db'
+        dbtype, dbver, fname = 'sqlite', 0.6, 'importdummy.db'
         db = pwman.data.factory.create(dbtype, dbver, fname)
         importer = Importer((args, '', db))
         importer.importer.run(callback=DummyCallback)

+ 7 - 6
pwman/tests/test_init.py

@@ -38,7 +38,8 @@ cls_timeout = 5
 [Database]
 """
 
-testdb = os.path.join(os.path.dirname(__file__), "test.pwman.db")
+testdb = os.path.abspath(os.path.join(os.path.dirname(__file__),
+                                      "test.pwman.db"))
 
 
 class TestFactory(unittest.TestCase):
@@ -65,16 +66,16 @@ class TestInit(unittest.TestCase):
 
     def setUp(self):
         "test that the right db instance was created"
-        self.dbtype = 'SQLite'
+        self.dbtype = 'sqlite'
         self.db = factory.create(self.dbtype, __DB_FORMAT__, testdb)
-        self.tester = SetupTester(__DB_FORMAT__, testdb)
+        self.tester = SetupTester(__DB_FORMAT__, dburi=testdb)
         self.tester.create()
 
     def test_get_db_version(self):
-        v = get_db_version(self.tester.configp, 'SQLite', None)
+        v = get_db_version(self.tester.configp, 'sqlite', None)
         self.assertEqual(v, __DB_FORMAT__)
         os.unlink(testdb)
-        v = get_db_version(self.tester.configp, 'SQLite', None)
+        v = get_db_version(self.tester.configp, 'sqlite', None)
         self.assertEqual(v, 0.6)
 
     def test_set_xsel(self):
@@ -93,7 +94,7 @@ class TestInit(unittest.TestCase):
         Args = namedtuple('args', 'cfile, dbase, algo')
         args = Args(cfile='dummy.cfg', dbase='dummy.db', algo='AES')
         xsel, dbtype, configp = get_conf_options(args, 'True')
-        self.assertEqual(dbtype, 'SQLite')
+        self.assertEqual(dbtype, 'sqlite')
 
 
 if __name__ == '__main__':

+ 9 - 12
pwman/tests/test_tools.py

@@ -52,7 +52,7 @@ class DummyCallback4(Callback):
         return u'newsecret'
 
 
-config.default_config['Database'] = {'type': 'SQLite',
+config.default_config['Database'] = {'type': 'sqlite',
                                      'filename':
                                      os.path.join(os.path.dirname(__file__),
                                                   "test.pwman.db"),
@@ -84,13 +84,13 @@ class SetupTester(object):
         self.configp = config.Config(os.path.join(os.path.dirname(__file__),
                                                   "test.conf"),
                                      config.default_config)
-        self.configp.set_value('Database', 'filename',
-                               os.path.join(os.path.dirname(__file__),
-                                            "test.pwman.db"))
+
         self.configp.set_value('Database', 'dburi',
-                               os.path.join('sqlite:///',
-                                            os.path.dirname(__file__),
-                                            "test.pwman.db"))
+                               'sqlite://' + os.path.join(os.path.abspath(
+                                                          os.path.dirname(__file__)),
+                                                          "test.pwman.db")
+                               )
+
         if not OSX:
             self.xselpath = which("xsel")
             self.configp.set_value("Global", "xsel", self.xselpath)
@@ -98,8 +98,7 @@ class SetupTester(object):
             self.xselpath = "xsel"
 
         self.dbver = dbver
-        self.filename = filename
-        self.dburi = dburi
+        self.dburi = self.configp.get_value('Database', 'dburi')
 
     def clean(self):
         dbfile = self.configp.get_value('Database', 'filename')
@@ -121,8 +120,6 @@ class SetupTester(object):
                                    'test.conf'))
 
     def create(self):
-        dbtype = 'SQLite'
-        db = factory.create(dbtype, self.dbver, self.filename)
-        #db = factory.createdb(self.dburi, self.dbver)
+        db = factory.createdb(self.dburi, self.dbver)
         self.cli = PwmanCliNew(db, self.xselpath, DummyCallback,
                                config_parser=self.configp)