Browse Source

Always connect to postgresql with username and pass

oz123 10 years ago
parent
commit
31df825242
4 changed files with 20 additions and 26 deletions
  1. 2 0
      .travis.yml
  2. 5 17
      pwman/data/drivers/postgresql.py
  3. 2 2
      pwman/data/factory.py
  4. 11 7
      pwman/tests/test_postgresql.py

+ 2 - 0
.travis.yml

@@ -4,7 +4,9 @@ python:
   - 3.4 
 
 before_script:
+  - psql -c "CREATE USER tester WITH PASSWORD '123456';" -U postgres
   - psql -c 'create database pwman;' -U postgres
+  - psql -c 'grant ALL ON DATABASE pwman to tester' -U postgres
 
 before_install:
   - sudo apt-get update -qq

+ 5 - 17
pwman/data/drivers/postgresql.py

@@ -20,11 +20,6 @@
 # ============================================================================
 
 """Postgresql Database implementation."""
-import sys
-if sys.version_info.major > 2:  # pragma: no cover
-    from urllib.parse import urlparse, ParseResult
-else:  # pragma: no cover
-    from urlparse import urlparse, ParseResult
 import psycopg2 as pg
 from pwman.data.database import Database, __DB_FORMAT__
 
@@ -49,10 +44,10 @@ class PostgresqlDatabase(Database):
         """
         Check the database version
         """
-        if isinstance(dburi, ParseResult):
-            con = pg.connect(dburi.geturl())
-        else:
-            con = pg.connect(dburi)
+        #if isinstance(dburi, ParseResult):
+        #    con = pg.connect(dburi.geturl())
+        #else:
+        con = pg.connect(dburi)
         cur = con.cursor()
         try:
             cur.execute("SELECT VERSION from DBVERSION")
@@ -62,7 +57,6 @@ class PostgresqlDatabase(Database):
             return version
         except pg.ProgrammingError:
             con.rollback()
-            #raise DatabaseException("Something seems fishy with the DB")
 
     def __init__(self, pgsqluri, dbformat=__DB_FORMAT__):
         """
@@ -73,13 +67,7 @@ class PostgresqlDatabase(Database):
 
     def _open(self):
 
-        try:
-            # TODO: remove this. we only want to accept url object
-            u = urlparse(self._pgsqluri)
-            self._con = pg.connect(database=u.path[1:], user=u.username,
-                                   password=u.password, host=u.hostname)
-        except AttributeError:
-            self._con = pg.connect(self._pgsqluri.geturl())
+        self._con = pg.connect(self._pgsqluri.geturl())
         self._cur = self._con.cursor()
         self._create_tables()
 

+ 2 - 2
pwman/data/factory.py

@@ -53,9 +53,9 @@ def check_db_version(dburi):
             return float(ver.strip("\'"))
         except ValueError:
             return 0.3
-    # TODO: implement version checks for other supported DBs.
     if dbtype == "postgresql":
-        ver = postgresql.PostgresqlDatabase.check_db_version(dburi)
+        #  ver = postgresql.PostgresqlDatabase.check_db_version(dburi)
+        ver = postgresql.PostgresqlDatabase.check_db_version(dburi.geturl())
 
 
 def createdb(dburi, version):

+ 11 - 7
pwman/tests/test_postgresql.py

@@ -17,11 +17,15 @@
 # Copyright (C) 2015 Oz Nahum Tiram <nahumoz@gmail.com>
 # ============================================================================
 import unittest
+import sys
+from .test_crypto_engine import give_key, DummyCallback
+if sys.version_info.major > 2:  # pragma: no cover
+    from urllib.parse import urlparse
+else:  # pragma: no cover
+    from urlparse import urlparse
 import psycopg2 as pg
 from pwman.data.drivers.postgresql import PostgresqlDatabase
 from pwman.util.crypto_engine import CryptoEngine
-from .test_crypto_engine import give_key, DummyCallback
-
 ##
 # testing on linux host
 # su - postgres
@@ -36,10 +40,10 @@ class TestPostGresql(unittest.TestCase):
 
     @classmethod
     def setUpClass(self):
-        # no password required, for testing in travis
-        u = "postgresql:///pwman"
-        # password required, for all other hosts
-        #u = "postgresql://<user>:<pass>@localhost/pwman"
+        u = "postgresql://tester:123456@localhost/pwman"
+        u = urlparse(u)
+        # password required, for all hosts
+        # u = "postgresql://<user>:<pass>@localhost/pwman"
         self.db = PostgresqlDatabase(u)
         self.db._open()
 
@@ -114,7 +118,7 @@ class TestPostGresql(unittest.TestCase):
 
     def test_9_check_db_version(self):
 
-        dburi = "postgresql:///pwman"
+        dburi = "postgresql://tester:123456@localhost/pwman"
         v = self.db.check_db_version(dburi)
         self.assertEqual(v, ('0.6',))
         self.db._cur.execute("DROP TABLE DBVERSION")