Parcourir la source

fix failing tests in MongoDB

Oz N Tiram il y a 8 ans
Parent
commit
f5a1396ee6
2 fichiers modifiés avec 18 ajouts et 14 suppressions
  1. 13 6
      pwman/data/drivers/mongodb.py
  2. 5 8
      tests/test_mongodb.py

+ 13 - 6
pwman/data/drivers/mongodb.py

@@ -18,6 +18,8 @@
 # ============================================================================
 
 from pwman.data.database import Database, __DB_FORMAT__
+from pwman.util.crypto_engine import CryptoEngine
+
 import pymongo
 
 
@@ -65,14 +67,19 @@ class MongoDB(Database):
 
         return nodes
 
-    def listnodes(self, filter=None):
-        if not filter:
+    def listnodes(self, filter_=None):
+        if not filter_:
             nodes = self._db.nodes.find({}, {'_id': 1})
-
+            return [node["_id"] for node in nodes]
         else:
-            nodes = self._db.nodes.find({"tags": {'$in': [filter]}}, {'_id': 1})
-
-        return [node['_id'] for node in list(nodes)]
+            matching = []
+            ce = CryptoEngine.get()
+            nodes = list(self._db.nodes.find({}, {'_id': 1, 'tags': 1}))
+            for node in nodes:
+                node['tags'] = [ce.decrypt(t) for t in node['tags']]
+                if filter_ in node['tags']:
+                    matching.append(node)
+            return [node["_id"] for node in matching]
 
     def add_node(self, node):
         nid = self._get_next_node_id()

+ 5 - 8
tests/test_mongodb.py

@@ -84,7 +84,7 @@ class TestMongoDB(unittest.TestCase):
                   [u"bartag", u"footag"]]
 
         kwargs = {
-            "username":innode[0], "password": innode[1],
+            "username": innode[0], "password": innode[1],
             "url": innode[2], "notes": innode[3], "tags": innode[4]
         }
 
@@ -99,17 +99,14 @@ class TestMongoDB(unittest.TestCase):
     def test_6_list_nodes(self):
         ret = self.db.listnodes()
         self.assertEqual(ret, [1])
-        ce = CryptoEngine.get()
-        fltr = ce.encrypt("footag")
-        ret = self.db.listnodes(fltr)
-        self.assertEqual(ret, [1])
+        ret = self.db.listnodes(filter_=b"footag")
 
     def test_6a_list_tags(self):
         ret = self.db.listtags()
         ce = CryptoEngine.get()
-        ec_tags = map(ce.encrypt,[u'bartag', u'footag'])
-        for t in ec_tags:
-            self.assertIn(t, ret)
+        tags = list(map(ce.decrypt, ret))
+        for tag in tags:
+            self.assertIn(tag, [b'footag', b'bartag'])
 
     def test_6b_get_nodes(self):
         ret = self.db.getnodes([1])