Kaynağa Gözat

Add more testing to postgresql driver

oz123 10 yıl önce
ebeveyn
işleme
96f608fcb1

+ 8 - 8
pwman/data/drivers/postgresql.py

@@ -82,10 +82,10 @@ class PostgresqlDatabase(Database):
             self._cur.execute(sql_all)
             ids = self._cur.fetchall()
             return [id[0] for id in ids]
-        else:                               # pragma: no cover
+        else:
             tagid = self._get_tag(filter)
             if not tagid:
-                return []
+                return []  # pragma: no cover
 
             sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
             self._cur.execute(sql_filter, (tagid))
@@ -100,7 +100,7 @@ class PostgresqlDatabase(Database):
         tags = self._cur.fetchall()
         if tags:
             return [t[0] for t in tags]
-        return []
+        return []  # pragma: no cover
 
     def _create_tables(self):
 
@@ -140,7 +140,7 @@ class PostgresqlDatabase(Database):
                               (self.dbversion,))
 
             self._con.commit()
-        except pg.ProgrammingError:
+        except pg.ProgrammingError:  # pragma: no cover
             self._con.rollback()
 
     def fetch_crypto_info(self):
@@ -192,9 +192,10 @@ class PostgresqlDatabase(Database):
 
     def _get_node_tags(self, node):  # pragma: no cover
         sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
-        tagids = self._cur.execute(sql, (str(node[0]),)).fetchall()
+        self._cur.execute(sql, (str(node[0]),))
+        tagids = self._cur.fetchall()
         sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-               "" % ','.join('%%s'*len(tagids)))
+               "" % ','.join(['%s']*len(tagids)))
         tagids = [str(id[0]) for id in tagids]
         self._cur.execute(sql, (tagids))
         tags = self._cur.fetchall()
@@ -208,8 +209,7 @@ class PostgresqlDatabase(Database):
         nodes = self._cur.fetchall()
         nodes_w_tags = []
         for node in nodes:
-            #tags = list(self._get_node_tags(node))
-            tags = []
+            tags = list(self._get_node_tags(node))
             nodes_w_tags.append(list(node) + tags)
 
         return nodes_w_tags

+ 11 - 3
pwman/tests/test_postgresql.py

@@ -71,13 +71,21 @@ class TestPostGresql(unittest.TestCase):
         self.assertEqual(row, ('TOP', 'SECRET'))
 
     def test_5_add_node(self):
-        innode = ["TBONE", "S3K43T", "example.org", "some note"]
+        innode = ["TBONE", "S3K43T", "example.org", "some note",
+                  ["footag", "bartag"]]
         self.db.add_node(innode)
         outnode = self.db.getnodes([1])[0]
-        self.assertEqual(innode, outnode[1:])
+        self.assertEqual(innode[:-1] + [t for t in innode[-1]], outnode[1:])
 
     def test_6_list_nodes(self):
-        self.db.listnodes()
+        ret = self.db.listnodes()
+        self.assertEqual(ret, [1])
+        ret = self.db.listnodes("footag")
+        self.assertEqual(ret, [1])
+
+    def test_6a_list_tags(self):
+        ret = self.db.listtags()
+        self.assertListEqual(ret, ['footag', 'bartag'])
 
     def test_7_get_or_create_tag(self):
         s = self.db._get_or_create_tag("SECRET")