Ver Fonte

Fix bug in getnodes

oz123 há 10 anos atrás
pai
commit
8512ef3c17
2 ficheiros alterados com 28 adições e 15 exclusões
  1. 22 14
      pwman/data/drivers/postgresql.py
  2. 6 1
      pwman/data/drivers/sqlite.py

+ 22 - 14
pwman/data/drivers/postgresql.py

@@ -26,7 +26,7 @@ if sys.version_info.major > 2:  # pragma: no cover
 else:
     from urlparse import urlparse
 import psycopg2 as pg
-from pwman.data.database import Database, DatabaseException, __DB_FORMAT__
+from pwman.data.database import Database, __DB_FORMAT__
 
 
 class PostgresqlDatabase(Database):
@@ -59,7 +59,7 @@ class PostgresqlDatabase(Database):
             return version
         except pg.ProgrammingError:
             con.rollback()
-            raise DatabaseException("Something seems fishy with the DB")
+            #raise DatabaseException("Something seems fishy with the DB")
 
     def __init__(self, pgsqluri, dbformat=__DB_FORMAT__):
         """
@@ -70,9 +70,13 @@ class PostgresqlDatabase(Database):
 
     def _open(self):
 
-        u = urlparse(self._pgsqluri)
-        self._con = pg.connect(database=u.path[1:], user=u.username,
-                               password=u.password, host=u.hostname)
+        try:
+            # TODO: remove this. we only want to accept url object
+            u = urlparse(self._pgsqluri)
+            self._con = pg.connect(database=u.path[1:], user=u.username,
+                                   password=u.password, host=u.hostname)
+        except AttributeError:
+            self._con = pg.connect(self._pgsqluri.geturl())
         self._cur = self._con.cursor()
         self._create_tables()
 
@@ -194,17 +198,21 @@ class PostgresqlDatabase(Database):
         sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
         self._cur.execute(sql, (str(node[0]),))
         tagids = self._cur.fetchall()
-        sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-               "" % ','.join(['%s']*len(tagids)))
-        tagids = [str(id[0]) for id in tagids]
-        self._cur.execute(sql, (tagids))
-        tags = self._cur.fetchall()
-        for t in tags:
-            yield t[0]
+        if tagids:
+            sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
+                   "" % ','.join(['%s']*len(tagids)))
+            tagids = [str(id[0]) for id in tagids]
+            self._cur.execute(sql, (tagids))
+            tags = self._cur.fetchall()
+            for t in tags:
+                yield t[0]
 
     def getnodes(self, ids):
-        sql = "SELECT * FROM NODE WHERE ID IN ({})".format(','.join('%s' for
-                                                                    i in ids))
+        if ids:
+            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
+                   "".format(','.join('%s' for i in ids)))
+        else:
+            sql = "SELECT * FROM NODE"
         self._cur.execute(sql, (ids))
         nodes = self._cur.fetchall()
         nodes_w_tags = []

+ 6 - 1
pwman/data/drivers/sqlite.py

@@ -174,7 +174,12 @@ class SQLite(Database):
         """
         get nodes as raw ciphertext
         """
-        sql = "SELECT * FROM NODE WHERE ID IN (%s)" % ','.join('?'*len(ids))
+        if ids:
+            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
+                   "".format(','.join('?'*len(ids))))
+        else:
+            sql = "SELECT * FROM NODE"
+
         self._cur.execute(sql, (ids))
         nodes = self._cur.fetchall()
         nodes_w_tags = []