Explorar o código

Pass all SQLite and Postgresql tests

Oz N Tiram %!s(int64=8) %!d(string=hai) anos
pai
achega
b8fb4ff3ef

+ 15 - 7
pwman/data/database.py

@@ -54,7 +54,6 @@ class Database(object):
             self._con.rollback()
             self._con.rollback()
 
 
     def _create_tables(self):
     def _create_tables(self):
-
         if self._check_tables():
         if self._check_tables():
             return
             return
         try:
         try:
@@ -67,7 +66,7 @@ class Database(object):
 
 
             self._cur.execute("CREATE TABLE TAG"
             self._cur.execute("CREATE TABLE TAG"
                               "(ID  SERIAL PRIMARY KEY,"
                               "(ID  SERIAL PRIMARY KEY,"
-                              "DATA VARCHAR(255) NOT NULL UNIQUE)")
+                              "DATA TEXT NOT NULL UNIQUE)")
 
 
             self._cur.execute("CREATE TABLE LOOKUP ("
             self._cur.execute("CREATE TABLE LOOKUP ("
                               "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
                               "nodeid INTEGER NOT NULL REFERENCES NODE(ID),"
@@ -124,14 +123,20 @@ class Database(object):
         self._cur.execute(sql_search)
         self._cur.execute(sql_search)
 
 
         ce = CryptoEngine.get()
         ce = CryptoEngine.get()
-        tag = ce.decrypt(tagcipher)
+
+        try:
+            tag = ce.decrypt(tagcipher)
+            encrypted = True
+        except Exception:
+            tag = tagcipher
+            encrypted = False
 
 
         rv = self._cur.fetchall()
         rv = self._cur.fetchall()
         for idx, cipher in rv:
         for idx, cipher in rv:
-            if tag == ce.decrypt(cipher):
+            if encrypted and tag == ce.decrypt(cipher):
+                return idx
+            elif tag == cipher:
                 return idx
                 return idx
-
-        return
 
 
     def _get_or_create_tag(self, tagcipher):
     def _get_or_create_tag(self, tagcipher):
         rv = self._get_tag(tagcipher)
         rv = self._get_tag(tagcipher)
@@ -158,9 +163,12 @@ class Database(object):
             sql = "SELECT * FROM NODE"
             sql = "SELECT * FROM NODE"
         self._cur.execute(sql, (ids))
         self._cur.execute(sql, (ids))
         nodes = self._cur.fetchall()
         nodes = self._cur.fetchall()
+        # sqlite returns nodes as bytes, postgresql returns them as str
+        if isinstance(nodes[0][1], str):
+            nodes = [node for node in nodes]
         nodes_w_tags = []
         nodes_w_tags = []
         for node in nodes:
         for node in nodes:
-            tags = list(self._get_node_tags(node))
+            tags = [t for t in self._get_node_tags(node)]
             nodes_w_tags.append(list(node) + tags)
             nodes_w_tags.append(list(node) + tags)
 
 
         return nodes_w_tags
         return nodes_w_tags

+ 1 - 1
pwman/data/drivers/sqlite.py

@@ -74,7 +74,7 @@ class SQLite(Database):
 
 
         self._cur.execute("CREATE TABLE TAG"
         self._cur.execute("CREATE TABLE TAG"
                           "(ID INTEGER PRIMARY KEY AUTOINCREMENT,"
                           "(ID INTEGER PRIMARY KEY AUTOINCREMENT,"
-                          "DATA BLOB NOT NULL UNIQUE)")
+                          "DATA BLOB NOT NULL)")
 
 
         self._cur.execute("CREATE TABLE LOOKUP ("
         self._cur.execute("CREATE TABLE LOOKUP ("
                           "nodeid INTEGER NOT NULL, "
                           "nodeid INTEGER NOT NULL, "

+ 5 - 1
pwman/util/crypto_engine.py

@@ -45,7 +45,11 @@ def encode_AES(cipher, clear_text):
 
 
 
 
 def decode_AES(cipher, encoded_text):
 def decode_AES(cipher, encoded_text):
-    return cipher.decrypt(base64.b64decode(encoded_text)).rstrip()
+    if not isinstance(encoded_text, bytes):
+        encoded_text = encoded_text.encode()
+
+    encoded_text = base64.b64decode(encoded_text)
+    return cipher.decrypt(encoded_text).rstrip()
 
 
 
 
 def generate_password(pass_len=8, uppercase=True, lowercase=True, digits=True,
 def generate_password(pass_len=8, uppercase=True, lowercase=True, digits=True,

+ 1 - 1
tests/test_postgresql.py

@@ -77,10 +77,10 @@ class TestPostGresql(unittest.TestCase):
         self.assertEqual(row, ('TOP', 'SECRET'))
         self.assertEqual(row, ('TOP', 'SECRET'))
 
 
     def test_5_add_node(self):
     def test_5_add_node(self):
+        # fuck, saving b"TBONE" has is harder ...
         innode = ["TBONE", "S3K43T", "example.org", "some note",
         innode = ["TBONE", "S3K43T", "example.org", "some note",
                   ["footag", "bartag"]]
                   ["footag", "bartag"]]
         self.db.add_node(innode)
         self.db.add_node(innode)
-
         outnode = self.db.getnodes([1])[0]
         outnode = self.db.getnodes([1])[0]
         self.assertEqual(innode[:-1] + [t for t in innode[-1]], outnode[1:])
         self.assertEqual(innode[:-1] + [t for t in innode[-1]], outnode[1:])