Quellcode durchsuchen

Add more testing to postgresql driver check_db_version

oz123 vor 10 Jahren
Ursprung
Commit
ee3045617c
2 geänderte Dateien mit 17 neuen und 6 gelöschten Zeilen
  1. 7 4
      pwman/data/drivers/postgresql.py
  2. 10 2
      pwman/tests/test_postgresql.py

+ 7 - 4
pwman/data/drivers/postgresql.py

@@ -22,9 +22,9 @@
 """Postgresql Database implementation."""
 import sys
 if sys.version_info.major > 2:  # pragma: no cover
-    from urllib.parse import urlparse
-else:
-    from urlparse import urlparse
+    from urllib.parse import urlparse, ParseResult
+else:  # pragma: no cover
+    from urlparse import urlparse, ParseResult
 import psycopg2 as pg
 from pwman.data.database import Database, __DB_FORMAT__
 
@@ -49,7 +49,10 @@ class PostgresqlDatabase(Database):
         """
         Check the database version
         """
-        con = pg.connect(dburi.geturl())
+        if isinstance(dburi, ParseResult):
+            con = pg.connect(dburi.geturl())
+        else:
+            con = pg.connect(dburi)
         cur = con.cursor()
         try:
             cur.execute("SELECT VERSION from DBVERSION")

+ 10 - 2
pwman/tests/test_postgresql.py

@@ -16,12 +16,11 @@
 # ============================================================================
 # Copyright (C) 2015 Oz Nahum Tiram <nahumoz@gmail.com>
 # ============================================================================
-
 import unittest
+import psycopg2 as pg
 from pwman.data.drivers.postgresql import PostgresqlDatabase
 from pwman.util.crypto_engine import CryptoEngine
 from .test_crypto_engine import give_key, DummyCallback
-import psycopg2 as pg
 
 ##
 # testing on linux host
@@ -107,6 +106,15 @@ class TestPostGresql(unittest.TestCase):
         n = self.db.listnodes()
         self.assertEqual(len(n), 0)
 
+    def test_9_check_db_version(self):
+
+        dburi = "postgresql:///pwman"
+        v = self.db.check_db_version(dburi)
+        self.assertEqual(v, ('0.6',))
+        self.db._cur.execute("DELETE FROM DBVERSION")
+        self.db._con.commit()
+        v = self.db.check_db_version(dburi)
+        self.assertEqual(v, None)
 
 if __name__ == '__main__':