Эх сурвалжийг харах

Completely implement mysql driver

oz123 10 жил өмнө
parent
commit
713a64e4bc

+ 0 - 1
.coveragerc

@@ -1,5 +1,4 @@
 [run]
 omit = pwman/tests/*.py, pwman/ui/mac.py, pwman/ui/win.py, \
        pwman/data/convertdb.py, \
-       pwman/data/drivers/mysql.py, 
 source = pwman

+ 132 - 0
pwman/data/drivers/mysql.py

@@ -108,3 +108,135 @@ class MySQLDatabase(Database):
             self._con.commit()
         except mysql.ProgrammingError:  # pragma: no cover
             self._con.rollback()
+
+    def getnodes(self, ids):
+        if ids:
+            sql = ("SELECT * FROM NODE WHERE ID IN ({})"
+                   "".format(','.join('%s' for i in ids)))
+        else:
+            sql = "SELECT * FROM NODE"
+        self._cur.execute(sql, (ids))
+        nodes = self._cur.fetchall()
+        nodes_w_tags = []
+        for node in nodes:
+            tags = list(self._get_node_tags(node))
+            nodes_w_tags.append(list(node) + tags)
+
+        return nodes_w_tags
+
+    def add_node(self, node):
+        sql = ("INSERT INTO NODE(USERNAME, PASSWORD, URL, NOTES)"
+               "VALUES(%s, %s, %s, %s)")
+        node_tags = list(node)
+        node, tags = node_tags[:4], node_tags[-1]
+        self._cur.execute(sql, (node))
+        nid = self._cur.lastrowid
+        self._setnodetags(nid, tags)
+        self._con.commit()
+
+    def _get_node_tags(self, node):
+        sql = "SELECT tagid FROM LOOKUP WHERE NODEID = %s"
+        self._cur.execute(sql, (str(node[0]),))
+        tagids = self._cur.fetchall()
+        if tagids:
+            sql = ("SELECT DATA FROM TAG WHERE ID IN (%s)"
+                   "" % ','.join(['%s']*len(tagids)))
+            tagids = [str(id[0]) for id in tagids]
+            self._cur.execute(sql, (tagids))
+            tags = self._cur.fetchall()
+            for t in tags:
+                yield t[0]
+
+    def _setnodetags(self, nodeid, tags):
+        for tag in tags:
+            tid = self._get_or_create_tag(tag)
+            self._update_tag_lookup(nodeid, tid)
+
+    def _get_tag(self, tagcipher):
+        sql_search = "SELECT ID FROM TAG WHERE DATA = %s"
+        self._cur.execute(sql_search, ([tagcipher]))
+        rv = self._cur.fetchone()
+        return rv
+
+    def _get_or_create_tag(self, tagcipher):
+        rv = self._get_tag(tagcipher)
+        if rv:
+            return rv[0]
+        else:
+            sql_insert = "INSERT INTO TAG(DATA) VALUES(%s)"
+            self._cur.execute(sql_insert, ([tagcipher]))
+            return self._cur.lastrowid
+
+    def _update_tag_lookup(self, nodeid, tid):
+        sql_lookup = "INSERT INTO LOOKUP(nodeid, tagid) VALUES(%s, %s)"
+        self._cur.execute(sql_lookup, (nodeid, tid))
+        self._con.commit()
+
+    def fetch_crypto_info(self):
+        self._cur.execute("SELECT * FROM CRYPTO")
+        row = self._cur.fetchone()
+        return row
+
+    def listtags(self):
+        self._clean_orphans()
+        get_tags = "select DATA from TAG"
+        self._cur.execute(get_tags)
+        tags = self._cur.fetchall()
+        if tags:
+            return [t[0] for t in tags]
+        return []  # pragma: no cover
+
+    def listnodes(self, filter=None):
+        if not filter:
+            sql_all = "SELECT ID FROM NODE"
+            self._cur.execute(sql_all)
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
+        else:
+            tagid = self._get_tag(filter)
+            if not tagid:
+                return []  # pragma: no cover
+
+            sql_filter = "SELECT NODEID FROM LOOKUP WHERE TAGID = %s "
+            self._cur.execute(sql_filter, (tagid))
+            self._con.commit()
+            ids = self._cur.fetchall()
+            return [id[0] for id in ids]
+
+    def save_crypto_info(self, seed, digest):
+        """save the random seed and the digested key"""
+        self._cur.execute("DELETE  FROM CRYPTO")
+        self._cur.execute("INSERT INTO CRYPTO VALUES(%s, %s)", (seed, digest))
+        self._con.commit()
+
+    def loadkey(self):
+        sql = "SELECT * FROM CRYPTO"
+        try:
+            self._cur.execute(sql)
+            seed, digest = self._cur.fetchone()
+            return seed + u'$6$' + digest
+        except TypeError:  # pragma: no cover
+            return None
+
+    def _clean_orphans(self):
+        clean = ("delete from TAG where not exists "
+                 "(select 'x' from LOOKUP l where l.TAGID = TAG.ID)")
+        self._cur.execute(clean)
+
+    def removenodes(self, nid):
+        # shall we do this also in the sqlite driver?
+        sql_clean = "DELETE FROM LOOKUP WHERE NODEID=%s"
+        self._cur.execute(sql_clean, nid)
+        sql_rm = "delete from NODE where ID = %s"
+        self._cur.execute(sql_rm, nid)
+        self._con.commit()
+        self._con.commit()
+
+    def savekey(self, key):
+        salt, digest = key.split('$6$')
+        sql = "INSERT INTO CRYPTO(SEED, DIGEST) VALUES(%s,%s)"
+        self._cur.execute("DELETE FROM CRYPTO")
+        self._cur.execute(sql, (salt, digest))
+        self._digest = digest.encode('utf-8')
+        self._salt = salt.encode('utf-8')
+        self._con.commit()

+ 74 - 3
pwman/tests/test_mysql.py

@@ -23,6 +23,7 @@ if sys.version_info.major > 2:  # pragma: no cover
     from urllib.parse import urlparse
 else:  # pragma: no cover
     from urlparse import urlparse
+import MySQLdb
 from pwman.data.drivers.mysql import MySQLDatabase
 from pwman.util.crypto_engine import CryptoEngine
 
@@ -31,10 +32,10 @@ class TestMySQLDatabase(unittest.TestCase):
 
     @classmethod
     def setUpClass(self):
-        u = "mysql://pwman:123456@localhost/pwmantest"
+        u = "mysql://pwman:123456@localhost:3306/pwmantest"
         u = urlparse(u)
         # password required, for all hosts
-        # u = "postgresql://<user>:<pass>@localhost/pwman"
+        # u = "mysql://<user>:<pass>@localhost/pwmantest"
         self.db = MySQLDatabase(u)
         self.db._open()
 
@@ -48,7 +49,77 @@ class TestMySQLDatabase(unittest.TestCase):
         self.db._con.commit()
 
     def test_1_con(self):
-        pass
+        self.assertIsInstance(self.db._con, MySQLdb.connections.Connection)
+
+    def test_2_create_tables(self):
+        self.db._create_tables()
+        # invoking this method a second time should not raise an exception
+        self.db._create_tables()
+
+    def test_3_load_key(self):
+        self.db.savekey('SECRET$6$KEY')
+        secretkey = self.db.loadkey()
+        self.assertEqual(secretkey, 'SECRET$6$KEY')
+
+    def test_4_save_crypto(self):
+        self.db.save_crypto_info("TOP", "SECRET")
+        secretkey = self.db.loadkey()
+        self.assertEqual(secretkey, 'TOP$6$SECRET')
+        row = self.db.fetch_crypto_info()
+        self.assertEqual(row, ('TOP', 'SECRET'))
+
+    def test_5_add_node(self):
+        innode = ["TBONE", "S3K43T", "example.org", "some note",
+                  ["bartag", "footag"]]
+        self.db.add_node(innode)
+
+        outnode = self.db.getnodes([1])[0]
+        self.assertEqual(innode[:-1] + [t for t in innode[-1]], outnode[1:])
+
+    def test_6_list_nodes(self):
+        ret = self.db.listnodes()
+        self.assertEqual(ret, [1])
+        ret = self.db.listnodes("footag")
+        self.assertEqual(ret, [1])
+
+    def test_6a_list_tags(self):
+        ret = self.db.listtags()
+        self.assertListEqual(ret, ['bartag', 'footag'])
+
+    def test_6b_get_nodes(self):
+        ret = self.db.getnodes([1])
+        retb = self.db.getnodes([])
+        self.assertListEqual(ret, retb)
+
+    def test_7_get_or_create_tag(self):
+        s = self.db._get_or_create_tag("SECRET")
+        s1 = self.db._get_or_create_tag("SECRET")
+
+        self.assertEqual(s, s1)
+
+    def test_7a_clean_orphans(self):
+
+        self.db._clean_orphans()
+        rv = self.db._get_tag("SECRET")
+        self.assertIsNone(rv)
+
+    def test_8_remove_node(self):
+        self.db.removenodes([1])
+        n = self.db.listnodes()
+        self.assertEqual(len(n), 0)
+
+    def test_9_check_db_version(self):
+
+        dburi = "mysql://pwman:123456@localhost:3306/pwmantest"
+        v = self.db.check_db_version(urlparse(dburi))
+        self.assertEqual(v, '0.6')
+        self.db._cur.execute("DROP TABLE DBVERSION")
+        self.db._con.commit()
+        v = self.db.check_db_version(urlparse(dburi))
+        self.assertEqual(v, None)
+        self.db._cur.execute("CREATE TABLE DBVERSION("
+                             "VERSION TEXT NOT NULL) ")
+        self.db._con.commit()
 
 if __name__ == '__main__':
 

+ 2 - 0
pwman/tests/test_pwman.py

@@ -26,6 +26,7 @@ from .test_crypto_engine import CryptoEngineTest, TestPassGenerator
 from .test_config import TestConfig
 from .test_sqlite import TestSQLite
 from .test_postgresql import TestPostGresql
+from .test_mysql import TestMySQLDatabase
 from .test_importer import TestImporter
 from .test_factory import TestFactory
 from .test_base_ui import TestBaseUI
@@ -55,6 +56,7 @@ def suite():
     suite.addTest(loader.loadTestsFromTestCase(TestConfig))
     suite.addTest(loader.loadTestsFromTestCase(TestSQLite))
     suite.addTest(loader.loadTestsFromTestCase(TestPostGresql))
+    suite.addTest(loader.loadTestsFromTestCase(TestMySQLDatabase))
     suite.addTest(loader.loadTestsFromTestCase(TestImporter))
     suite.addTest(loader.loadTestsFromTestCase(TestFactory))
     suite.addTest(loader.loadTestsFromTestCase(TestBaseUI))