فهرست منبع

DRY - remove duplicate driver code

oz123 10 سال پیش
والد
کامیت
d82c1dab13
6فایلهای تغییر یافته به همراه44 افزوده شده و 175 حذف شده
  1. 27 17
      pwman/data/database.py
  2. 6 4
      pwman/data/drivers/mysql.py
  3. 2 0
      pwman/data/drivers/postgresql.py
  4. 5 151
      pwman/data/drivers/sqlite.py
  5. 2 2
      tests/test_postgresql.py
  6. 2 1
      tests/test_sqlite.py

+ 27 - 17
pwman/data/database.py

@@ -101,14 +101,15 @@ class Database(object):
         clean = ("delete from TAG where not exists "
                  "(select 'x' from LOOKUP l where l.TAGID = TAG.ID)")
         self._cur.execute(clean)
+        self._con.commit()
 
     def _get_node_tags(self, node):
-        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
+        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = {}".format(self._sub)
         self._cur.execute(sql, (str(node[0]),))
         tagids = self._cur.fetchall()
         if tagids:
-            sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-                   "" % ','.join(['%s']*len(tagids)))
+            sql = ("SELECT DATA FROM TAG WHERE ID IN"
+                   " ({})".format(','.join([self._sub]*len(tagids))))
             tagids = [str(id[0]) for id in tagids]
             self._cur.execute(sql, (tagids))
             tags = self._cur.fetchall()
@@ -121,7 +122,7 @@ class Database(object):
             self._update_tag_lookup(nodeid, tid)
 
     def _get_tag(self, tagcipher):
-        sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
+        sql_search = "SELECT ID FROM TAG WHERE DATA = {}".format(self._sub)
         self._cur.execute(sql_search, ([tagcipher]))
         rv = self._cur.fetchone()
         return rv
@@ -138,14 +139,15 @@ class Database(object):
                 return self._cur.lastrowid
 
     def _update_tag_lookup(self, nodeid, tid):
-        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(%s, %s)"
+        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES({}, {})".format(
+            self._sub, self._sub)
         self._cur.execute(sql_lookup, (nodeid, tid))
         self._con.commit()
 
     def getnodes(self, ids):
         if ids:
             sql = ("SELECT * FROM NODE WHERE ID IN ({})"
-                   "".format(','.join('%s' for i in ids)))
+                   "".format(','.join(self._sub for i in ids)))
         else:
             sql = "SELECT * FROM NODE"
         self._cur.execute(sql, (ids))
@@ -158,6 +160,7 @@ class Database(object):
         return nodes_w_tags
 
     def listnodes(self, filter=None):
+        """return a list of node ids"""
         if not filter:
             sql_all = "SELECT ID FROM NODE"
             self._cur.execute(sql_all)
@@ -168,8 +171,7 @@ class Database(object):
             if not tagid:
                 return []  # pragma: no cover
 
-            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
-            self._cur.execute(sql_filter, (tagid))
+            self._cur.execute(self._list_nodes_sql, (tagid))
             self._con.commit()
             ids = self._cur.fetchall()
             return [id[0] for id in ids]
@@ -194,17 +196,19 @@ class Database(object):
             return [t[0] for t in tags]
         return []  # pragma: no cover
 
-    # TODO: add this to tests !
-    def editnode(self, nid, **kwargs):  # pragma: no cover
+    # TODO: add this to tests of postgresql and mysql!
+    def editnode(self, nid, **kwargs):
         tags = kwargs.pop('tags', None)
-        sql = ("UPDATE NODE SET %s WHERE ID = %%s "
-               "" % ','.join('%s=%%s' % k for k in list(kwargs)))
+        sql = ("UPDATE NODE SET {} WHERE ID = {} ".format(
+            ','.join(['{}={}'.format(k, self._sub) for k in list(kwargs)]),
+            self._sub))
+
         self._cur.execute(sql, (list(kwargs.values()) + [nid]))
         if tags:
             # update all old node entries in lookup
             # create new entries
             # clean all old tags
-            sql_clean = "DELETE FROM LOOKUP WHERE NODEID=?"
+            sql_clean = "DELETE FROM LOOKUP WHERE NODEID={}".format(self._sub)
             self._cur.execute(sql_clean, (str(nid),))
             self._setnodetags(nid, tags)
 
@@ -212,9 +216,9 @@ class Database(object):
 
     def removenodes(self, nid):
         # shall we do this also in the sqlite driver?
-        sql_clean = "DELETE FROM LOOKUP WHERE NODEID=%s"
+        sql_clean = "DELETE FROM LOOKUP WHERE NODEID={}".format(self._sub)
         self._cur.execute(sql_clean, nid)
-        sql_rm = "delete from NODE where ID = %s"
+        sql_rm = "delete from NODE where ID = {}".format(self._sub)
         self._cur.execute(sql_rm, nid)
         self._con.commit()
         self._con.commit()
@@ -227,10 +231,15 @@ class Database(object):
     def save_crypto_info(self, seed, digest):
         """save the random seed and the digested key"""
         self._cur.execute("DELETE  FROM CRYPTO")
-        self._cur.execute("INSERT INTO CRYPTO VALUES(%s, %s)", (seed, digest))
+        self._cur.execute("INSERT INTO CRYPTO VALUES({}, {})".format(self._sub,
+                                                                     self._sub),
+                          (seed, digest))
         self._con.commit()
 
     def loadkey(self):
+        """
+        return _keycrypted
+        """
         sql = "SELECT * FROM CRYPTO"
         try:
             self._cur.execute(sql)
@@ -241,7 +250,8 @@ class Database(object):
 
     def savekey(self, key):
         salt, digest = key.split('$6$')
-        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(%s,%s)"
+        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES({},{})".format(self._sub,
+                                                                      self._sub)
         self._cur.execute("DELETE FROM CRYPTO")
         self._cur.execute(sql, (salt, digest))
         self._digest = digest.encode('utf-8')

+ 6 - 4
pwman/data/drivers/mysql.py

@@ -16,10 +16,10 @@
 # ============================================================================
 # Copyright (C) 2012-2015 Oz Nahum <nahumoz@gmail.com>
 # ============================================================================
-#mysql -u root -p
-#create database pwmantest
-#create user 'pwman'@'localhost' IDENTIFIED BY '123456';
-#grant all on pwmantest.* to 'pwman'@'localhost';
+# mysql -u root -p
+# create database pwmantest
+# create user 'pwman'@'localhost' IDENTIFIED BY '123456';
+# grant all on pwmantest.* to 'pwman'@'localhost';
 
 """MySQL Database implementation."""
 from pwman.data.database import Database, __DB_FORMAT__
@@ -53,6 +53,8 @@ class MySQLDatabase(Database):
     def __init__(self, mysqluri, dbformat=__DB_FORMAT__):
         self.dburi = mysqluri
         self.dbversion = dbformat
+        self._sub = "%s"
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
         self._add_node_sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, "
                               "NOTES) "
                               "VALUES(%s, %s, %s, %s)")

+ 2 - 0
pwman/data/drivers/postgresql.py

@@ -61,6 +61,8 @@ class PostgresqlDatabase(Database):
         """
         self._pgsqluri = pgsqluri
         self.dbversion = dbformat
+        self._sub = "%s"
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
         self._add_node_sql = ('INSERT INTO NODE(USERNAME, PASSWORD, URL, '
                               'NOTES) VALUES(%s, %s, %s, %s) RETURNING ID')
         self._insert_tag_sql = "INSERT INTO TAG(DATA) VALUES(%s) RETURNING ID"

+ 5 - 151
pwman/data/drivers/sqlite.py

@@ -45,36 +45,17 @@ class SQLite(Database):
         """Initialise SQLitePwmanDatabase instance."""
         self._filename = filename
         self.dbformat = dbformat
+        self._add_node_sql = ("INSERT INTO NODE(USER, PASSWORD, URL, NOTES)"
+                              "VALUES(?, ?, ?, ?)")
+        self._list_nodes_sql = "SELECT NODEID FROM LOOKUP WHERE TAGID = ? "
+        self._insert_tag_sql = "INSERT INTO TAG(DATA) VALUES(?)"
+        self._sub = '?'
 
     def _open(self):
         self._con = sqlite.connect(self._filename)
         self._cur = self._con.cursor()
         self._create_tables()
 
-    def listnodes(self, filter=None):
-        """return a list of node ids"""
-        if not filter:
-            sql_all = "SELECT ID FROM NODE"
-            self._cur.execute(sql_all)
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-        else:
-            tagid = self._get_tag(filter)
-            if not tagid:
-                return []
-            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = ? "
-            self._cur.execute(sql_filter, (tagid))
-            ids = self._cur.fetchall()
-            return [id[0] for id in ids]
-
-    def listtags(self):
-        self._clean_orphans()
-        get_tags = "select data from tag"
-        self._cur.execute(get_tags)
-        tags = self._cur.fetchall()
-        if tags:
-            return [t[0] for t in tags]
-        return []
 
     def _create_tables(self):
         self._cur.execute("PRAGMA TABLE_INFO(NODE)")
@@ -114,130 +95,3 @@ class SQLite(Database):
             self._con.rollback()
             raise e
 
-    def fetch_crypto_info(self):
-        self._cur.execute("SELECT * FROM CRYPTO")
-        keyrow = self._cur.fetchone()
-        return keyrow
-
-    def save_crypto_info(self, seed, digest):
-        """save the random seed and the digested key"""
-        self._cur.execute("DELETE  FROM CRYPTO")
-        self._cur.execute("INSERT INTO CRYPTO VALUES(?, ?)", [seed, digest])
-        self._con.commit()
-
-    def add_node(self, node):
-        sql = ("INSERT INTO NODE(USER, PASSWORD, URL, NOTES)"
-               "VALUES(?, ?, ?, ?)")
-        node_tags = list(node)
-        node, tags = node_tags[:4], node_tags[-1]
-        self._cur.execute(sql, (node))
-        self._setnodetags(self._cur.lastrowid, tags)
-        self._con.commit()
-
-    def _get_tag(self, tagcipher):
-        sql_search = "SELECT ID FROM TAG WHERE DATA = ?"
-        self._cur.execute(sql_search, ([tagcipher]))
-        rv = self._cur.fetchone()
-        return rv
-
-    def _get_or_create_tag(self, tagcipher):
-        rv = self._get_tag(tagcipher)
-        if rv:
-            return rv[0]
-        else:
-            sql_insert = "INSERT INTO TAG(DATA) VALUES(?)"
-            self._cur.execute(sql_insert, ([tagcipher]))
-            return self._cur.lastrowid
-
-    def _update_tag_lookup(self, nodeid, tid):
-        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(?,?)"
-        self._cur.execute(sql_lookup, (nodeid, tid))
-        self._con.commit()
-
-    def _setnodetags(self, nodeid, tags):
-        for tag in tags:
-            tid = self._get_or_create_tag(tag)
-            self._update_tag_lookup(nodeid, tid)
-
-    def _get_node_tags(self, node):
-        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = ?"
-        tagids = self._cur.execute(sql, (str(node[0]),)).fetchall()
-        sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
-               "" % ','.join('?'*len(tagids)))
-        tagids = [str(id[0]) for id in tagids]
-        self._cur.execute(sql, (tagids))
-        tags = self._cur.fetchall()
-        for t in tags:
-            yield t[0]
-
-    def getnodes(self, ids):
-        """
-        get nodes as raw ciphertext
-        """
-        if ids:
-            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
-                   "".format(','.join('?'*len(ids))))
-        else:
-            sql = "SELECT * FROM NODE"
-
-        self._cur.execute(sql, (ids))
-        nodes = self._cur.fetchall()
-        nodes_w_tags = []
-        for node in nodes:
-            tags = list(self._get_node_tags(node))
-            nodes_w_tags.append(list(node) + tags)
-
-        return nodes_w_tags
-
-    def editnode(self, nid, **kwargs):
-        tags = kwargs.pop('tags', None)
-        sql = ("UPDATE NODE SET %s WHERE ID = ? "
-               "" % ','.join('%s=?' % k for k in list(kwargs)))
-        self._cur.execute(sql, (list(kwargs.values()) + [nid]))
-        if tags:
-            # update all old node entries in lookup
-            # create new entries
-            # clean all old tags
-            sql_clean = "DELETE FROM LOOKUP WHERE NODEID=?"
-            self._cur.execute(sql_clean, (str(nid),))
-            self._setnodetags(nid, tags)
-
-        self._con.commit()
-
-    def removenodes(self, nids):
-        sql_rm = "delete from node where id in (%s)" % ','.join('?'*len(nids))
-        self._cur.execute(sql_rm, (nids))
-        self._con.commit()
-
-    def _clean_orphans(self):
-        clean = ("delete from tag where not exists "
-                 "(select 'x' from lookup l where l.tagid = tag.id)")
-
-        self._cur.execute(clean)
-        self._con.commit()
-
-    def savekey(self, key):
-        salt, digest = key.split('$6$')
-        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(?,?)"
-        self._cur.execute("DELETE FROM CRYPTO")
-        self._cur.execute(sql, (salt, digest))
-        self._digest = digest.encode('utf-8')
-        self._salt = salt.encode('utf-8')
-        self._con.commit()
-
-    def loadkey(self):
-        # TODO: rename this method!
-        """
-        return _keycrypted
-        """
-        sql = "SELECT * FROM CRYPTO"
-        try:
-            seed, digest = self._cur.execute(sql).fetchone()
-            return seed + u'$6$' + digest
-        except TypeError:
-            return None
-
-    def close(self):
-        self._clean_orphans()
-        self._cur.close()
-        self._con.close()

+ 2 - 2
tests/test_postgresql.py

@@ -30,8 +30,8 @@ from pwman.util.crypto_engine import CryptoEngine
 # testing on linux host
 # su - postgres
 # psql
-# postgres=# create user $YOUR_USERNAME;
-# postgres=# grant ALL ON DATABASE pwman to $YOUR_USERNAME;
+# postgres=# CREATE USER tester WITH PASSWORD '123456';
+# postgres=# grant ALL ON DATABASE pwman to tester;
 #
 ##
 

+ 2 - 1
tests/test_sqlite.py

@@ -151,7 +151,8 @@ class TestSQLite(unittest.TestCase):
         self.assertEqual(4, len(list(tags)))
 
     def test_a11_test_rmnodes(self):
-        self.db.removenodes([1, 2])
+        for n in [1, 2]:
+            self.db.removenodes([n])
         rv = self.db._cur.execute("select * from node").fetchall()
         self.assertListEqual(rv, [])