Browse Source

Add _create_tables to mysql

oz123 10 years ago
parent
commit
647cc1d1ca
2 changed files with 65 additions and 4 deletions
  1. 55 0
      pwman/data/drivers/mysql.py
  2. 10 4
      pwman/tests/test_mysql.py

+ 55 - 0
pwman/data/drivers/mysql.py

@@ -53,3 +53,58 @@ class MySQLDatabase(Database):
         self._mysqluri = mysqluri
         self.dbversion = dbformat
 
+    def _open(self):
+
+        port = 3306
+        credentials, host = self.dburi.netloc.split('@')
+        user, passwd = credentials.split(':')
+        if ':' in host:
+            host, port = host.split(':')
+            port = int(port)
+
+        self._con = mysql.connect(host=host, port=port, user=user,
+                                  passwd=passwd,
+                                  db=self.dburi.path.lstrip('/'))
+        self._cur = self._con.cursor()
+        self._create_tables()
+
+    def _create_tables(self):
+
+        try:
+            self._cur.execute("SELECT 1 from DBVERSION")
+            version = self._cur.fetchone()
+            if version:
+                return
+        except mysql.ProgrammingError:
+            self._con.rollback()
+
+        try:
+            self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
+                              "USERNAME TEXT NOT NULL, "
+                              "PASSWORD TEXT NOT NULL, "
+                              "URL TEXT NOT NULL, "
+                              "NOTES TEXT NOT NULL"
+                              ")")
+
+            self._cur.execute("CREATE TABLE TAG"
+                              "(ID SERIAL PRIMARY KEY,"
+                              "DATA TEXT NOT NULL UNIQUE)")
+
+            self._cur.execute("CREATE TABLE LOOKUP ("
+                              "nodeid SERIAL REFERENCES NODE(ID),"
+                              "tagid SERIAL REFERENCES TAG(ID)"
+                              ")")
+
+            self._cur.execute("CREATE TABLE CRYPTO "
+                              "(SEED TEXT, DIGEST TEXT)")
+
+            self._cur.execute("CREATE TABLE DBVERSION("
+                              "VERSION TEXT NOT NULL DEFAULT {}"
+                              ")".format(__DB_FORMAT__))
+
+            self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
+                              (self.dbversion,))
+
+            self._con.commit()
+        except mysql.ProgrammingError:  # pragma: no cover
+            self._con.rollback()

+ 10 - 4
pwman/tests/test_mysql.py

@@ -23,20 +23,19 @@ if sys.version_info.major > 2:  # pragma: no cover
     from urllib.parse import urlparse
 else:  # pragma: no cover
     from urlparse import urlparse
-import psycopg2 as pg
 from pwman.data.drivers.mysql import MySQLDatabase
 from pwman.util.crypto_engine import CryptoEngine
 
 
-class TestPostGresql(unittest.TestCase):
+class TestMySQLDatabase(unittest.TestCase):
 
     @classmethod
     def setUpClass(self):
-        u = "postgresql://tester:123456@localhost/pwman"
+        u = "mysql://pwman:123456@localhost/pwmantest"
         u = urlparse(u)
         # password required, for all hosts
         # u = "postgresql://<user>:<pass>@localhost/pwman"
-        self.db = PostgresqlDatabase(u)
+        self.db = MySQLDatabase(u)
         self.db._open()
 
     @classmethod
@@ -47,3 +46,10 @@ class TestPostGresql(unittest.TestCase):
         self.db._cur.execute("DROP TABLE DBVERSION")
         self.db._cur.execute("DROP TABLE CRYPTO")
         self.db._con.commit()
+
+if __name__ == '__main__':
+
+    ce = CryptoEngine.get()
+    ce.callback = DummyCallback()
+    ce.changepassword(reader=give_key)
+    unittest.main(verbosity=2, failfast=True)