Sfoglia il codice sorgente

Add tag insertion to pgsql driver

oz123 10 anni fa
parent
commit
798375dac5
2 ha cambiato i file con 43 aggiunte e 6 eliminazioni
  1. 33 3
      pwman/data/drivers/postgresql.py
  2. 10 3
      pwman/tests/test_postgresql.py

+ 33 - 3
pwman/data/drivers/postgresql.py

@@ -74,13 +74,43 @@ class PostgresqlDatabase(Database):
         self._cur = self._con.cursor()
         self._create_tables()
 
+    def _get_tag(self, tagcipher):
+        sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
+        self._cur.execute(sql_search, ([tagcipher]))
+        rv = self._cur.fetchone()
+        return rv
+
+    def _get_or_create_tag(self, tagcipher):
+        rv = self._get_tag(tagcipher)
+        if rv:
+            return rv[0]
+        else:
+            sql_insert = "INSERT INTO TAG(DATA) VALUES(%s) RETURNING ID"
+            self._cur.execute(sql_insert, ([tagcipher]))
+            rid = self._cur.fetchone()[0]
+            return rid
+
     def close(self):
         # TODO: implement _clean_orphands
         self._cur.close()
         self._con.close()
 
-    def listtags(self, filter=None):
-        pass
+    def listnodes(self, filter=None):
+        if not filter:
+            sql_all = "SELECT ID FROM NODE"
+            self._cur.execute(sql_all)
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
+        else:
+            tagid = self._get_tag(filter)
+            if not tagid:
+                return []
+
+            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
+            self._cur.execute(sql_filter, (tagid))
+            self._con.commit()
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
 
     def editnode(self, nid, **kwargs):
         pass
@@ -110,7 +140,7 @@ class PostgresqlDatabase(Database):
     def removenodes(self, nodes):
         pass
 
-    def listnodes(self):
+    def listtags(self):
         pass
 
     def _create_tables(self):

+ 10 - 3
pwman/tests/test_postgresql.py

@@ -17,11 +17,8 @@
 # Copyright (C) 2015 Oz Nahum Tiram <nahumoz@gmail.com>
 # ============================================================================
 
-import os
 import unittest
-import sys
 from pwman.data.drivers.postgresql import PostgresqlDatabase
-from pwman.data.nodes import Node
 from pwman.util.crypto_engine import CryptoEngine
 from .test_crypto_engine import give_key, DummyCallback
 import psycopg2 as pg
@@ -71,6 +68,16 @@ class TestPostGresql(unittest.TestCase):
         outnode = self.db.getnodes([1])[0]
         self.assertEqual(innode, outnode[1:])
 
+    def test_6_list_nodes(self):
+        self.db.listnodes()
+
+    def test_7_get_or_create_tag(self):
+        s = self.db._get_or_create_tag("SECRET")
+        s1 = self.db._get_or_create_tag("SECRET")
+
+        self.assertEqual(s, s1)
+
+
 if __name__ == '__main__':
 
     ce = CryptoEngine.get()