Bläddra i källkod

Pass all SQLite and Postgresql tests

Oz N Tiram 8 år sedan
förälder
incheckning
b8fb4ff3ef
4 ändrade filer med 22 tillägg och 10 borttagningar
  1. 15 7
      pwman/data/database.py
  2. 1 1
      pwman/data/drivers/sqlite.py
  3. 5 1
      pwman/util/crypto_engine.py
  4. 1 1
      tests/test_postgresql.py

+ 15 - 7
pwman/data/database.py

@@ -54,7 +54,6 @@ class Database(object):
             self._con.rollback()
 
     def _create_tables(self):
-
         if self._check_tables():
             return
         try:
@@ -67,7 +66,7 @@ class Database(object):
 
             self._cur.execute("CREATE TABLE TAG"
                               "(ID  SERIAL PRIMARY KEY,"
-                              "DATA VARCHAR(255) NOT NULL UNIQUE)")
+                              "DATA TEXT NOT NULL UNIQUE)")
 
             self._cur.execute("CREATE TABLE LOOKUP ("
                               "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
@@ -124,14 +123,20 @@ class Database(object):
         self._cur.execute(sql_search)
 
         ce = CryptoEngine.get()
-        tag = ce.decrypt(tagcipher)
+
+        try:
+            tag = ce.decrypt(tagcipher)
+            encrypted = True
+        except Exception:
+            tag = tagcipher
+            encrypted = False
 
         rv = self._cur.fetchall()
         for idx, cipher in rv:
-            if tag == ce.decrypt(cipher):
+            if encrypted and tag == ce.decrypt(cipher):
+                return idx
+            elif tag == cipher:
                 return idx
-
-        return
 
     def _get_or_create_tag(self, tagcipher):
         rv = self._get_tag(tagcipher)
@@ -158,9 +163,12 @@ class Database(object):
             sql = "SELECT * FROM NODE"
         self._cur.execute(sql, (ids))
         nodes = self._cur.fetchall()
+        # sqlite returns nodes as bytes, postgresql returns them as str
+        if isinstance(nodes[0][1], str):
+            nodes = [node for node in nodes]
         nodes_w_tags = []
         for node in nodes:
-            tags = list(self._get_node_tags(node))
+            tags = [t for t in self._get_node_tags(node)]
             nodes_w_tags.append(list(node) + tags)
 
         return nodes_w_tags

+ 1 - 1
pwman/data/drivers/sqlite.py

@@ -74,7 +74,7 @@ class SQLite(Database):
 
         self._cur.execute("CREATE TABLE TAG"
                           "(ID INTEGER PRIMARY KEY AUTOINCREMENT,"
-                          "DATA BLOB NOT NULL UNIQUE)")
+                          "DATA BLOB NOT NULL)")
 
         self._cur.execute("CREATE TABLE LOOKUP ("
                           "nodeid INTEGER NOT NULL, "

+ 5 - 1
pwman/util/crypto_engine.py

@@ -45,7 +45,11 @@ def encode_AES(cipher, clear_text):
 
 
 def decode_AES(cipher, encoded_text):
-    return cipher.decrypt(base64.b64decode(encoded_text)).rstrip()
+    if not isinstance(encoded_text, bytes):
+        encoded_text = encoded_text.encode()
+
+    encoded_text = base64.b64decode(encoded_text)
+    return cipher.decrypt(encoded_text).rstrip()
 
 
 def generate_password(pass_len=8, uppercase=True, lowercase=True, digits=True,

+ 1 - 1
tests/test_postgresql.py

@@ -77,10 +77,10 @@ class TestPostGresql(unittest.TestCase):
         self.assertEqual(row, ('TOP', 'SECRET'))
 
     def test_5_add_node(self):
+        # fuck, saving b"TBONE" has is harder ...
         innode = ["TBONE", "S3K43T", "example.org", "some note",
                   ["footag", "bartag"]]
         self.db.add_node(innode)
-
         outnode = self.db.getnodes([1])[0]
         self.assertEqual(innode[:-1] + [t for t in innode[-1]], outnode[1:])