|
@@ -26,7 +26,7 @@ if sys.version_info.major > 2: # pragma: no cover
|
|
|
else:
|
|
|
from urlparse import urlparse
|
|
|
import psycopg2 as pg
|
|
|
-from pwman.data.database import Database, DatabaseException, __DB_FORMAT__
|
|
|
+from pwman.data.database import Database, __DB_FORMAT__
|
|
|
|
|
|
|
|
|
class PostgresqlDatabase(Database):
|
|
@@ -59,7 +59,7 @@ class PostgresqlDatabase(Database):
|
|
|
return version
|
|
|
except pg.ProgrammingError:
|
|
|
con.rollback()
|
|
|
- raise DatabaseException("Something seems fishy with the DB")
|
|
|
+ #raise DatabaseException("Something seems fishy with the DB")
|
|
|
|
|
|
def __init__(self, pgsqluri, dbformat=__DB_FORMAT__):
|
|
|
"""
|
|
@@ -70,9 +70,13 @@ class PostgresqlDatabase(Database):
|
|
|
|
|
|
def _open(self):
|
|
|
|
|
|
- u = urlparse(self._pgsqluri)
|
|
|
- self._con = pg.connect(database=u.path[1:], user=u.username,
|
|
|
- password=u.password, host=u.hostname)
|
|
|
+ try:
|
|
|
+ # TODO: remove this. we only want to accept url object
|
|
|
+ u = urlparse(self._pgsqluri)
|
|
|
+ self._con = pg.connect(database=u.path[1:], user=u.username,
|
|
|
+ password=u.password, host=u.hostname)
|
|
|
+ except AttributeError:
|
|
|
+ self._con = pg.connect(self._pgsqluri.geturl())
|
|
|
self._cur = self._con.cursor()
|
|
|
self._create_tables()
|
|
|
|
|
@@ -194,17 +198,21 @@ class PostgresqlDatabase(Database):
|
|
|
sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
|
|
|
self._cur.execute(sql, (str(node[0]),))
|
|
|
tagids = self._cur.fetchall()
|
|
|
- sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
|
|
|
- "" % ','.join(['%s']*len(tagids)))
|
|
|
- tagids = [str(id[0]) for id in tagids]
|
|
|
- self._cur.execute(sql, (tagids))
|
|
|
- tags = self._cur.fetchall()
|
|
|
- for t in tags:
|
|
|
- yield t[0]
|
|
|
+ if tagids:
|
|
|
+ sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
|
|
|
+ "" % ','.join(['%s']*len(tagids)))
|
|
|
+ tagids = [str(id[0]) for id in tagids]
|
|
|
+ self._cur.execute(sql, (tagids))
|
|
|
+ tags = self._cur.fetchall()
|
|
|
+ for t in tags:
|
|
|
+ yield t[0]
|
|
|
|
|
|
def getnodes(self, ids):
|
|
|
- sql = "SELECT * FROM NODE WHERE ID IN ({})".format(','.join('%s' for
|
|
|
- i in ids))
|
|
|
+ if ids:
|
|
|
+ sql = ("SELECT * FROM NODE WHERE ID IN ({})"
|
|
|
+ "".format(','.join('%s' for i in ids)))
|
|
|
+ else:
|
|
|
+ sql = "SELECT * FROM NODE"
|
|
|
self._cur.execute(sql, (ids))
|
|
|
nodes = self._cur.fetchall()
|
|
|
nodes_w_tags = []
|