Browse Source

Extend create_tables of postgresql

oz123 10 năm trước cách đây
mục cha
commit
62886a8bc2
2 tập tin đã thay đổi với 50 bổ sung22 xóa
  1. 42 18
      pwman/data/drivers/postgresql.py
  2. 8 4
      pwman/tests/test_postgresql.py

+ 42 - 18
pwman/data/drivers/postgresql.py

@@ -32,6 +32,7 @@ from pwman.data.database import Database, DatabaseException, __DB_FORMAT__
 
 
 class PostgresqlDatabase(Database):
+
     """
     Postgresql Database implementation
 
@@ -79,6 +80,7 @@ class PostgresqlDatabase(Database):
         Initialise PostgresqlDatabase instance.
         """
         self._pgsqluri = pgsqluri
+        self.dbversion = dbformat
 
     def _open(self):
 
@@ -86,8 +88,6 @@ class PostgresqlDatabase(Database):
         self._con = pg.connect(database=u.path[1:], user=u.username,
                                password=u.password, host=u.hostname)
         self._cur = self._con.cursor()
-        self._cur.execute("CREATE TABLE DBVERSION(VERSION TEXT NOT NULL "
-                          "DEFAULT '0.6')")
 
     def _get_cur(self):
         try:
@@ -209,7 +209,8 @@ class PostgresqlDatabase(Database):
     def addnodes(self, nodes):
         cursor = self._get_cur()
         for n in nodes:
-            sql = "INSERT INTO %sNODES(DATA) VALUES(%%(data)s)" % (self._prefix)
+            sql = "INSERT INTO %sNODES(DATA) VALUES(%%(data)s)" % (
+                self._prefix)
             if not isinstance(n, Node):
                 raise DatabaseException("Tried to insert foreign object into "
                                         "database [%s]", n)
@@ -296,7 +297,8 @@ class PostgresqlDatabase(Database):
                 continue
             except KeyError as e:
                 pass  # not in cache
-            sql = "SELECT ID FROM %sTAGS WHERE DATA = %%(tag)s" % (self._prefix)
+            sql = "SELECT ID FROM %sTAGS WHERE DATA = %%(tag)s" % (
+                self._prefix)
             if not isinstance(t, Tag):
                 raise DatabaseException("Tried to insert foreign object"
                                         " into database [%s]", t)
@@ -322,7 +324,8 @@ class PostgresqlDatabase(Database):
     def _deletenodetags(self, node):
         try:
             cursor = self._get_cur()
-            sql = "DELETE FROM %sLOOKUP WHERE NODE = %%(node)d" % (self._prefix)
+            sql = "DELETE FROM %sLOOKUP WHERE NODE = %%(node)d" % (
+                self._prefix)
             cursor.execute(sql, {"node": node.get_id()})
 
         except pgdb.DatabaseError as e:
@@ -368,30 +371,51 @@ class PostgresqlDatabase(Database):
     def _create_tables(self):
 
         try:
-            self._cur.execute("SELECT 1 from NODE")
+            self._cur.execute("SELECT 1 from DBVERSION")
             version = self._cur.fetchone()
             if version:
                 return
         except pg.ProgrammingError:
             self._con.rollback()
 
-        self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
-                          "USERNAME TEXT NOT NULL, "
-                          "PASSWORD TEXT NOT NULL, "
-                          "URL TEXT NOT NULL, "
-                          "NOTES TEXT NOT NULL"
-                          ")")
-
-        self._cur.execute("CREATE TABLE DBVERSION("
-                          "VERSION TEXT NOT NULL DEFAULT {}"
-                          ")".format(__DB_FORMAT__))
-    #def _checktables(self):
+        try:
+            self._cur.execute("CREATE TABLE NODE(ID SERIAL PRIMARY KEY, "
+                              "USERNAME TEXT NOT NULL, "
+                              "PASSWORD TEXT NOT NULL, "
+                              "URL TEXT NOT NULL, "
+                              "NOTES TEXT NOT NULL"
+                              ")")
+
+            self._cur.execute("CREATE TABLE TAG"
+                              "(ID SERIAL PRIMARY KEY,"
+                              "DATA TEXT NOT NULL UNIQUE)")
+
+            self._cur.execute("CREATE TABLE LOOKUP ("
+                              "nodeid SERIAL REFERENCES NODE(ID),"
+                              "tagid SERIAL REFERENCES TAG(ID)"
+                              ")")
+
+            self._cur.execute("CREATE TABLE CRYPTO "
+                              "(SEED TEXT, DIGEST TEXT)")
+
+            self._cur.execute("CREATE TABLE DBVERSION("
+                              "VERSION TEXT NOT NULL DEFAULT {}"
+                              ")".format(__DB_FORMAT__))
+
+            self._cur.execute("INSERT INTO DBVERSION VALUES(%s)",
+                              (self.dbversion,))
+
+            self._con.commit()
+        except pg.ProgrammingError:
+            self._con.rollback()
+            #raise E
+    # def _checktables(self):%
     #    """ Check if the Pwman tables exist """
     #    cursor = self._get_cur()
     #    cursor.execute("SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE "
     #                   "TABLE_NAME = '%snodes'" % (self._prefix))
     #    if not cursor.fetchone():
-    #        # table doesn't exist, create it
+    # table doesn't exist, create it
     #        cursor.execute(("CREATE TABLE %sNODES "
     #                        "(ID SERIAL PRIMARY KEY, DATA TEXT NOT NULL)"
     #                        % (self._prefix)))

+ 8 - 4
pwman/tests/test_postgresql.py

@@ -38,9 +38,12 @@ class TestPostGresql(unittest.TestCase):
 
     @classmethod
     def tearDownClass(self):
-
-        self.db._cur.execute("TRUNCATE DBVERSION")
-        self.db._cur.execute("TRUNCATE NODE")
+        self.db._cur.execute("DROP TABLE LOOKUP")
+        self.db._cur.execute("DROP TABLE TAG")
+        self.db._cur.execute("DROP TABLE NODE")
+        self.db._cur.execute("DROP TABLE DBVERSION")
+        self.db._cur.execute("DROP TABLE CRYPTO")
+        self.db._con.commit()
 
     def test_1_con(self):
 
@@ -48,7 +51,8 @@ class TestPostGresql(unittest.TestCase):
 
     def test_2_create_tables(self):
         self.db._create_tables()
-
+        # invoking this method a second time should not raise an exception
+        self.db._create_tables()
 if __name__ == '__main__':
 
     ce = CryptoEngine.get()